diff options
Diffstat (limited to 'src')
20 files changed, 363 insertions, 81 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index 50fc5bdb8a..58e4861bc0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -33,21 +33,21 @@ #include "kernels/a32_sgemm_8x6.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp" #include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp" #include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp" #include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_sgemm_8x12.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp" #include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp" @@ -89,7 +89,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sve_ffinterleaved_bf16fp32_mmla_8x3VL", @@ -106,7 +106,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs 
&args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_HYBRID, @@ -136,7 +136,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_ffinterleaved_bf16fp32_mmla_8x12", @@ -161,7 +161,7 @@ GemmImplementation<bfloat16, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); }, [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<bfloat16, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_sgemm_8x12", @@ -197,7 +197,7 @@ const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, fl /* Explicitly instantiate the external functions for these types. 
*/ template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index 2796b0d204..d749dce98d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -34,17 +34,17 @@ #include "gemm_interleaved.hpp" #include "kernels/a32_sgemm_8x6.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp" #include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hgemm_8x24.hpp" #include "kernels/a64_hybrid_fp16_mla_6x32.hpp" #include "kernels/a64_sgemm_8x12.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp" #include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp" #include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp" @@ -66,7 +66,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, 
__fp16>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<__fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sve_ffinterleaved_fp16_mla_8x3VL", @@ -83,7 +83,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE #if defined(__aarch64__) GemmImplementation<__fp16, __fp16>::with_estimate( @@ -100,7 +100,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation<__fp16, __fp16>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_ffinterleaved_fp16_mla_8x24", @@ -117,7 +117,7 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS { GemmMethod::GEMM_INTERLEAVED, "a64_sgemm_8x12", @@ -152,7 +152,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1 /* Explicitly instantiate the external functions for these types. 
*/ template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 4f7e191fb3..0fc9e8b912 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -31,12 +31,12 @@ #include "gemv_pretransposed.hpp" #include "kernels/a32_sgemm_8x6.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp" #include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp" #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp" #include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp" #include "kernels/a64_hybrid_fp32_mla_4x24.hpp" @@ -48,12 +48,12 @@ #include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp" -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp" #include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp" #include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp" #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include 
"kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp" #include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp" @@ -165,7 +165,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); } ), - #ifdef ENABLE_FIXED_FORMAT_KERNELS + #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #ifdef ARM_COMPUTE_ENABLE_BF16 GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, @@ -200,7 +200,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE // Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases. 
{ @@ -253,7 +253,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); } ), -#ifdef ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #ifdef ARM_COMPUTE_ENABLE_BF16 // "fast mode" (BF16) kernels GemmImplementation<float, float>::with_estimate( @@ -289,7 +289,7 @@ GemmImplementation<float, float>::with_estimate( [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); }, [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); } ), -#endif // ENABLE_FIXED_FORMAT_KERNELS +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // __aarch64__ #ifdef __arm__ @@ -318,7 +318,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>() /* Explicitly instantiate the external functions for these types. 
*/ template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index c41b0a5b3e..90e2f07607 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -450,7 +450,7 @@ public: } /* Make sure we've been set up correctly. */ - assert(_B_transposed); + assert(FixedFormat || _B_transposed); static_assert(std::is_same<To, Tloi>::value, "gemm_native: Operand types must be the same."); // static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same."); diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 75fb1cb306..19c8fcadd3 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -306,9 +306,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons } template<typename Top, typename Tret, class OutputStage> -bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) { +bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) { const GemmImplementation<Top, Tret, OutputStage> *impl; - return find_implementation<Top, Tret, OutputStage>(args, os, impl); + const bool success = find_implementation<Top, Tret, OutputStage>(args, os, impl); + if (success) + wf = 
UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format; + return success; } template<typename Top, typename Tret, class OutputStage> diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp index 3915861112..18d8fc9312 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp @@ -56,7 +56,7 @@ const GemmImplementation<int16_t, int32_t> *gemm_implementation_list<int16_t, in /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<int16_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index 0c68e4dd99..24507486ac 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -159,7 +159,7 @@ const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int3 /* Explicitly instantiate the external functions for these types. 
*/ template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<int8_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index 6b813c7974..1d7b9c5b73 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -230,7 +230,7 @@ const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list } template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); -template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm<int8_t, int8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os); template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index 95139c2bf6..be7a4ee570 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -197,7 +197,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li } template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); -template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm<uint8_t, uint8_t, 
Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os); template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp index 20cee556f0..fc836f9790 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp @@ -56,7 +56,7 @@ const GemmImplementation<uint16_t, uint32_t> *gemm_implementation_list<uint16_t, /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index a2d2cc86f0..03e9cd6c1f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -157,7 +157,7 @@ const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, u /* Explicitly instantiate the external functions for these types. 
*/ template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp index 18e124b83e..d7b5398488 100644 --- a/src/core/NEON/kernels/arm_gemm/utils.hpp +++ b/src/core/NEON/kernels/arm_gemm/utils.hpp @@ -24,7 +24,7 @@ #pragma once -#include "arm_gemm.hpp" +#include "src/cpu/kernels/assembly/arm_gemm.hpp" #include <cstddef> #include <limits> diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index 247cb1d470..48fd7c6b43 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -47,6 +47,57 @@ enum class GemmMethod GEMM_HYBRID_QUANTIZED }; +/** Memory layouts for the weights tensor. + * + * * UNSPECIFIED is used to select kernels that do not run in + * variable weights mode. + * + * * ANY is used to query the kernel database to retrieve any of the + * kernels that runs in variable weights mode. Once a kernel is + * found, the specific format expected by the kernel can be + * retrieved by the user for reordering the weights tensor + * accordingly. + * + * The other values OHWIo{interleave_by}i{block_by} describe the + * memory layout of a 4D tensor with layout OHWI that has been + * transformed into a 4D tensor with dimensions O'HWI' where: + * + * O' = first multiple of {interleave_by} s.t. O<=O' + * I' = first multiple of {block_by} s.t. 
I<=I' + * + * The total size of the dst tensor is O' x H x W x I' + * + * The access function of the tensor with layout + * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter + * access function, where the 6 parameters are computed as follows: + * + * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by} + * + * x4 = h RANGE [0, H-1] SIZE: H + * x3 = w RANGE [0, W-1] SIZE: W + * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by} + * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by} + * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by} + * TOTAL SIZE: O' * H * W * I' + * + * 4D 6D + * ----------------- ----------------------------------- + * value(o, h, w, i) = x5 * H * W * I' * {interleave_by} + * + x4 * W * I' * {interleave_by} + * + x3 * I' * {interleave_by} + * + x2 * {interleave_by} * {block_by} + * + x1 * {block_by} + * + x0 + * + * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created + * for the OHWIo{interleave_by}i{block_by} format is in reality seen + * as a 2D tensor, where the number of rows is O'/{interleave_by} + * and the number of columns is {interleave_by} * H * W * I'. + * + * The postfix *_bf16 is for the memory layout needed for the + * fast-mode kernels, in which the weights are passed in bfloat16 + * format. 
+ */ enum class WeightFormat { UNSPECIFIED = 0x1, @@ -87,6 +138,69 @@ enum class WeightFormat OHWIo64i8 = 0x804000 }; +// OHWIo<interleave_by>i<block_by> +inline int interleave_by(const WeightFormat wf) +{ + return ((int)wf >> 8) & 0xFFF; +} +inline int block_by(const WeightFormat wf) +{ + return ((int)wf >> 20) & 0xF; +} +inline bool is_fixed_format(const WeightFormat wf) +{ + return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY; +} + +inline std::string to_string(WeightFormat wf) +{ +#define __CASE_WEIGHT_FORMAT(wf) \ +case WeightFormat::wf: \ + return #wf; + switch(wf) + { + __CASE_WEIGHT_FORMAT(UNSPECIFIED) + __CASE_WEIGHT_FORMAT(ANY) + __CASE_WEIGHT_FORMAT(OHWI) + __CASE_WEIGHT_FORMAT(OHWIo2) + __CASE_WEIGHT_FORMAT(OHWIo4) + __CASE_WEIGHT_FORMAT(OHWIo8) + __CASE_WEIGHT_FORMAT(OHWIo16) + __CASE_WEIGHT_FORMAT(OHWIo32) + __CASE_WEIGHT_FORMAT(OHWIo64) + __CASE_WEIGHT_FORMAT(OHWIo128) + __CASE_WEIGHT_FORMAT(OHWIo4i2) + __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i2) + __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i2) + __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i2) + __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i2) + __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo4i4) + __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i4) + __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i4) + __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i4) + __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i4) + __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo2i8) + __CASE_WEIGHT_FORMAT(OHWIo4i8) + __CASE_WEIGHT_FORMAT(OHWIo8i8) + __CASE_WEIGHT_FORMAT(OHWIo16i8) + __CASE_WEIGHT_FORMAT(OHWIo32i8) + __CASE_WEIGHT_FORMAT(OHWIo64i8) + default: + return "invalid value"; + } +#undef __CASE_WEIGHT_FORMAT +} + struct KernelDescription { GemmMethod method = GemmMethod::DEFAULT; @@ -230,6 +344,6 @@ 
template <typename Top, typename Tret, class OutputStage = Nothing> std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); template <typename Top, typename Tret, class OutputStage = Nothing> -bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {}); +bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {}); } // namespace arm_gemm diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index 61cd11ece0..f3fff608dc 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -51,6 +51,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.activation_info = info.activation_info(); asm_info.fast_mode = info.fast_math(); asm_info.fixed_format = info.fixed_format(); + asm_info.weight_format = info.weight_format(); return asm_info; } @@ -177,7 +178,8 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens if(d->total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != d->dimension(0)); + // For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more. 
+ ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0)); if(gemm_info.depth_output_gemm3d() != 0) { if(gemm_info.reinterpret_input_as_3d()) @@ -277,7 +279,7 @@ void CpuGemm::run(ITensorPack &tensors) auto c = tensors.get_const_tensor(ACL_SRC_2); auto d = tensors.get_tensor(ACL_DST); - if(_asm_glue->is_configured()) + if(_asm_glue && _asm_glue->is_configured()) { // Pass c to asm dispatch only if it's the bias tensor ITensorPack asm_pack = tensors; @@ -343,7 +345,7 @@ void CpuGemm::prepare(ITensorPack &tensors) { if(!_is_prepared) { - if(_asm_glue->is_configured()) + if(_asm_glue && _asm_glue->is_configured()) { _asm_glue->prepare(tensors); } @@ -365,5 +367,18 @@ experimental::MemoryRequirements CpuGemm::workspace() const { return _aux_mem; } + +Status CpuGemm::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, + const GEMMInfo &gemm_info) +{ + const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); + + return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info); +} + +bool CpuGemm::isVarWeightsKernel() const +{ + return _asm_glue && _asm_glue->isVarWeightsKernel(); +} } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h index 334ab6c647..b37ab73485 100644 --- a/src/cpu/operators/CpuGemm.h +++ b/src/cpu/operators/CpuGemm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -101,11 +101,29 @@ public: static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo()); + /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. 
+ * + * This method has the same use of @ref + * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that + * the value of arm_gemm::WeightFormat need to be passed via the + * parameter gemm_info. + */ + static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, + const GEMMInfo &gemm_info = GEMMInfo()); + // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &constants) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &constants) override; experimental::MemoryRequirements workspace() const override; + /** Indicates if the convolution executes in variable weights mode. + * + * When ACL executes convolution in variable weights mode, it does + * not perform any processing of the weights tensor. Instead, it + * utilizes the data as it is given by the user. + */ + bool isVarWeightsKernel() const; + private: enum AuxTensorIdx { diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp index c021d31059..0174d0eed3 100644 --- a/src/cpu/operators/CpuGemmConv2d.cpp +++ b/src/cpu/operators/CpuGemmConv2d.cpp @@ -99,15 +99,15 @@ CpuGemmConv2d::CpuGemmConv2d() CpuGemmConv2d::~CpuGemmConv2d() = default; void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info, - bool enable_fast_math, int gemm_3d_depth) + bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_gemm::WeightFormat weight_format) { ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights); - ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col)); + ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format)); // Create GEMMInfo structure const 
GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info); + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format); // Supported activations in GEMM const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU, @@ -156,7 +156,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info); _mm_gemmlowp = std::make_unique<CpuGemmLowpMatrixMultiplyCore>(); - _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info)); + _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info, + experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format)); auto mm_mem_req = _mm_gemmlowp->workspace(); for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont) @@ -178,7 +179,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig } Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, - const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col) + const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_gemm::WeightFormat weight_format) { const DataType data_type = src->data_type(); const bool is_quantized = is_data_type_quantized_asymmetric(data_type); @@ -187,7 +188,7 @@ 
Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei // Create GEMMInfo structure const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info); + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weight_format); if(is_quantized) { @@ -227,6 +228,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei std::unique_ptr<ITensorInfo> weights_qa = weights->clone(); input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset)); weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset)); + return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info, false, enable_fast_math, false, act_info)); } @@ -294,6 +296,7 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights kernel_height, conv_info, dilation); + ARM_COMPUTE_ERROR_ON_MSG((dst->dimension(idx_width) != conv_w) || (dst->dimension(idx_height) != conv_h), "Output shape does not match the expected one"); @@ -357,7 +360,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights // Configure GEMM // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix const unsigned int gemm_3d_depth = _skip_col2im ? 
conv_h : 0; - configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth); + const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED; + configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format()); if(!_skip_col2im && _data_layout == DataLayout::NCHW) { @@ -384,6 +388,38 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights _aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size()); } +Status CpuGemmConv2d::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, + const PadStrideInfo &conv_info, + const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math) +{ + const DataLayout data_layout = src->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const unsigned int kernel_width = weights->dimension(idx_width); + const unsigned int kernel_height = weights->dimension(idx_height); + unsigned int conv_w = 0; + unsigned int conv_h = 0; + std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), + src->dimension(idx_height), + kernel_width, + kernel_height, + conv_info, + dilation); + + const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info, + dilation, act_info); + + const bool skip_im2col = skip_info.skip_im2col; + const bool skip_col2im = skip_info.skip_col2im; + const unsigned int gemm_3d_depth = skip_col2im ? 
conv_h : 0; + const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED; + const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, + gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList<ITensorInfo *>(), fixed_format, weights_info.weight_format()); + + return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info); +} + Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info, const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups) { @@ -450,7 +486,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases); } - ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels)); + ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != dst->dimension(idx_channel)); ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); } @@ -472,7 +508,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type); im2col_reshaped_info.set_quantization_info(src->quantization_info()); - ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation)); + ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, 1)); gemm_input_to_use = &im2col_reshaped_info; } @@ -490,8 +526,11 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight 
info_gemm = TensorInfo(dst->tensor_shape(), 1, output_data_type); } info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout()); - gemm_output_to_use = &info_gemm; - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col)); + gemm_output_to_use = &info_gemm; + const bool fixed_format = weights_info.weight_format() != arm_gemm::WeightFormat::UNSPECIFIED; + + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col, fixed_format, + weights_info.weight_format())); // Validate Col2Im/ReshapeLayer if(!skip_col2im && (data_layout == DataLayout::NCHW)) @@ -548,7 +587,10 @@ void CpuGemmConv2d::run(ITensorPack &tensors) // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions ITensorPack pack_mm = tensors; pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use); - pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); + if(!this->isVarWeightsKernel()) + { + pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); + } pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use); if(_is_quantized) { @@ -598,6 +640,15 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { + // Variable weights executions that use fixed-format kernels + // need no reshaping of the weights. + if(this->isVarWeightsKernel()) + { + _is_quantized ? 
_mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors); + _is_prepared = true; + return; + } + // Run weights reshaping and mark original weights tensor as unused CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors); auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -608,12 +659,9 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors) }; NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack); weights->mark_as_unused(); - - // Prepare GEMM ITensorPack gemm_pack = tensors; gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get()); _is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack); - _is_prepared = true; } } @@ -621,5 +669,9 @@ experimental::MemoryRequirements CpuGemmConv2d::workspace() const { return _aux_mem; } +bool CpuGemmConv2d::isVarWeightsKernel() const +{ + return _mm_gemm && _mm_gemm->isVarWeightsKernel(); +} } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h index aec4a2ffa5..f8f0bce048 100644 --- a/src/cpu/operators/CpuGemmConv2d.h +++ b/src/cpu/operators/CpuGemmConv2d.h @@ -117,6 +117,17 @@ public: const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(), bool enable_fast_math = false, unsigned int num_groups = 1); + /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. + * + * The parameter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl + * + * @return a status. 
+ */ + static Status has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, + const PadStrideInfo &conv_info, + const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(), + const bool enable_fast_math = false); + // Inherited methods overridden: void run(ITensorPack &tensors) override; void prepare(ITensorPack &tensors) override; @@ -135,9 +146,11 @@ private: * @param[in] enable_fast_math (Optional) Enable fast math computation. In case this flag were set, the function could dispatch the fastest implementation * available which may introduce a drop of accuracy as well. Default is false * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) + * @param[in] fixed_format (Optional) Select GEMM execution with variable weights. + * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights. */ void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(), - bool enable_fast_math = false, int gemm_3d_depth = 1); + bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED); /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines * * @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32. @@ -151,11 +164,13 @@ private: * available which may introduce a drop of accuracy as well. Default is false * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) * @param[in] skip_im2col (Optional) Flag which specifies if im2col has to be skipped. i.e. 
1x1 convolution with NHWC data layout. (Default to false) + * @param[in] fixed_format (Optional) Select GEMM execution with variable weights. + * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights. * * @return a status */ static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(), - bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false); + bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_gemm::WeightFormat weight_format = arm_gemm::WeightFormat::UNSPECIFIED); /** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore * * @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32. @@ -187,6 +202,11 @@ private: static SkipInfo skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info, const Size2D &dilation, const ActivationLayerInfo &act_info); + /** Indicates if the convolution executes in variable weights mode. 
+ * + * Similar to @ref CpuGemm::isVarWeightsKernel + */ + bool isVarWeightsKernel() const; enum AuxTensorIdx { // CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 787ea95372..5694a3d9ee 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -160,6 +160,13 @@ public: void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; + bool isVarWeightsKernel() const override + { + if(!_gemm_kernel_asm) + return false; + const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format; + return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY; + } private: enum AuxTensorIdx @@ -420,6 +427,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors) // Pretranspose B if required if(_gemm_kernel_asm->B_pretranspose_required()) { + // Fixed format kernels need no pretranspose. 
+ ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format)); const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes()); const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); @@ -483,7 +492,24 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors) // Check if B is pre-tranposed and de-reference if not if(!_gemm_kernel_asm->B_is_pretransposed()) { - ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); + ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput); + const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format; + if(is_fixed_format(wf)) + { + // The 4D tensor of dimension O'HWI' created for the + // OHWIo<interleave_by>i<block_by> format is in reality seen + // as a 2D tensor at arm_gemm level, where the rows are + // O'/<interleave_by> and the columns are <interleave_by> * + // H * W * I'. 
+ ITensorInfo *tensor_info = b->info(); + const DataLayout data_layout = tensor_info->data_layout(); + const TensorShape tensor_shape = tensor_info->tensor_shape(); + const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; + const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; + const int Ip = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; + const int interleave_by = arm_gemm::interleave_by(wf); + ldb = (interleave_by * H * W * Ip); + } multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput); in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes()); } @@ -576,7 +602,9 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode); + arm_gemm::GemmConfig cfg; + cfg.weight_format = info.weight_format; + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>(); @@ -594,7 +622,9 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> & const CPUInfo &ci = NEScheduler::get().cpu_info(); const unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode); + arm_gemm::GemmConfig cfg; + cfg.weight_format = info.weight_format; + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, 
num_threads, info.fixed_format, info.fast_mode, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>(); @@ -635,7 +665,8 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() { } -Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) +Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, + const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_UNUSED(c); @@ -643,12 +674,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor Params p = extract_parameters(a, b, d, info); const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); + arm_gemm::GemmConfig cfg; + cfg.weight_format = info.weight_format; - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg); switch(a->data_type()) { case DataType::F32: - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})), "We could not find an optimized kernel for F32 input"); break; #ifdef __aarch64__ @@ -656,12 +689,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor case DataType::QASYMM8: if(d->data_type() == DataType::S32) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})), + 
ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})), "We could not find an optimized kernel for U8/QASYMM8 input and S32 output"); } else { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})), "We could not find an optimized kernel for U8 input and U8 output"); } break; @@ -669,12 +702,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor case DataType::QASYMM8_SIGNED: if(d->data_type() == DataType::S32) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})), "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); } else { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})), "We could not find an optimized kernel for S8 input and S32 output"); } break; @@ -689,7 +722,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})), "We could not find an optimized kernel for BFLOAT16 input and F32 output"); break; #endif /* 
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ @@ -729,7 +762,17 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info); + arm_gemm::WeightFormat expected_weight_format; + const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); + if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY) + { + // Correctness check: if the format expected by the kernel is + // not "any", make sure that the one found matches the format + // intended by the caller. 
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format), + "The format expected by the kernel does not correspond with the one requested by the user."); + } + return ret; } bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) @@ -801,7 +844,7 @@ void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors) bool CpuGemmAssemblyDispatch::is_configured() const { - return _arm_gemm != nullptr && _arm_gemm->is_configured(); + return _arm_gemm && _arm_gemm->is_configured(); } void CpuGemmAssemblyDispatch::run(ITensorPack &tensors) diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 3c25866f25..4ef108d430 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -53,6 +53,7 @@ struct AsmGemmInfo float padding_value{ 0.f }; bool fast_mode{ false }; bool fixed_format{ false }; + arm_gemm::WeightFormat weight_format{ arm_gemm::WeightFormat::UNSPECIFIED }; }; /** Assembly kernel glue */ @@ -73,6 +74,7 @@ public: virtual void prepare(ITensorPack &tensors) = 0; virtual experimental::MemoryRequirements workspace() const = 0; virtual bool is_configured() const = 0; + virtual bool isVarWeightsKernel() const = 0; virtual ~IFallback() = default; }; @@ -101,15 +103,14 @@ public: /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. * - * @param[in] a Input tensor info (Matrix A) - * @param[in] b Input tensor info (Matrix B) - * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations - * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. 
- * @param[in] info GEMM meta-data + * This method has the same use as @ref + * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that + * the value of arm_gemm::WeightFormat needs to be passed via the + * parameter info. * * @return a status. */ - static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); + static Status has_opt_impl(arm_gemm::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); /** Checks if activation is supported by the gemm assembly dispatcher * * @param[in] activation Activation to check @@ -122,6 +123,14 @@ public: * @return True if the function is configured and ready to run */ bool is_configured() const; + /** Indicates if the convolution executes in variable weights mode. + * + * Similar to @ref CpuGemm::isVarWeightsKernel + */ + bool isVarWeightsKernel() const + { + return _arm_gemm && _arm_gemm->isVarWeightsKernel(); + } // Inherited methods overridden: void prepare(ITensorPack &tensors) override; diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index c780d63763..13635c6e34 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -58,6 +58,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); + _impl->weights = weights; _impl->op = std::make_unique<cpu::CpuGemmConv2d>(); _impl->op->configure(input->info(), weights->info(), (biases != nullptr ? 
biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups); @@ -79,6 +80,13 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI return cpu::CpuGemmConv2d::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups); } +Status NEGEMMConvolutionLayer::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, + const PadStrideInfo &conv_info, + const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math) +{ + return cpu::CpuGemmConv2d::has_opt_impl(expected_weight_format, src, weights, biases, dst, conv_info, weights_info, dilation, act_info, enable_fast_math); +} + void NEGEMMConvolutionLayer::run() { prepare(); |