From e33c55640b5e6e6af193a41a3376e2a01a321a35 Mon Sep 17 00:00:00 2001 From: "Francesco.Petrogalli@arm.com" Date: Thu, 31 Mar 2022 17:55:35 +0000 Subject: [arm_gemm] Use static validate to find arm_gemm kernels. The static method `CpuGemmAssemblyDispatch::validate` should look into the list of the available kernels to make sure the one requested by the user was found. Formatting changes in the files touched by the patch have been automatically inserted by the formatting script. Resolves: ONCPUML-840 Change-Id: Icd650a30e142284a942c64f8a2b72441ee7b3f4e Signed-off-by: Francesco.Petrogalli@arm.com Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7375 Tested-by: Arm Jenkins Reviewed-by: Giorgio Arena Comments-Addressed: Arm Jenkins --- src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp | 1 + src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 1 + src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 1 + .../NEON/kernels/arm_gemm/gemm_implementation.hpp | 6 ++ src/core/NEON/kernels/arm_gemm/gemm_int16.cpp | 1 + src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 1 + src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp | 1 + src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp | 1 + src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp | 1 + src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp | 1 + src/cpu/kernels/assembly/arm_gemm.hpp | 5 +- .../operators/internal/CpuGemmAssemblyDispatch.cpp | 84 +++++++++++++++++++--- .../operators/internal/CpuGemmAssemblyDispatch.h | 28 +++++--- 13 files changed, 114 insertions(+), 18 deletions(-) diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index f4af587898..dd72fb5901 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -144,6 +144,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index a502262de3..42f4528066 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -108,6 +108,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1 /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 8b855ab07c..69a2803903 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -232,6 +232,7 @@ const GemmImplementation *gemm_implementation_list() /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 4d7f7983f8..cb3ff7aa29 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -235,6 +235,12 @@ std::vector get_compatible_kernels(const GemmArgs &args, cons return res; } +template +bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation *impl; + return find_implementation(args, os, impl); +} + template UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage &os) { const GemmImplementation *impl; diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp index d650116bf7..3915861112 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp @@ -56,6 +56,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index a1134559df..0c68e4dd99 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -159,6 +159,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index 1532816b74..6b813c7974 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -230,6 +230,7 @@ const GemmImplementation *gemm_implementation_list } template UniqueGemmCommon gemm(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm(const GemmArgs &args, const Requantize32 &os); template std::vector get_compatible_kernels(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index a80766bad6..95139c2bf6 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -197,6 +197,7 @@ const GemmImplementation *gemm_implementation_li } template UniqueGemmCommon gemm(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm(const GemmArgs &args, const Requantize32 &os); template std::vector get_compatible_kernels(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp index d459df8126..20cee556f0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp @@ -56,6 +56,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index f2d46d5415..a2d2cc86f0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -157,6 +157,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index e38cc09202..200e04f9a8 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -187,4 +187,7 @@ UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage & = {}) template std::vector get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); +template +bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {}); + } // namespace arm_gemm diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 657f3b8e35..496b55ec1c 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -156,8 +156,8 @@ public: const std::vector &multipliers); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; @@ -203,12 +203,12 @@ private: /** Indirect buffer */ std::unique_ptr _indirect_arg{}; std::unique_ptr _indirect_buf{}; - std::vector _indirect_pad{}; - arm_gemm::ConvolutionParameters _cp{}; - experimental::MemoryRequirements _aux_mem{ Count }; - bool _B_pretranspose_required{ false }; - bool _is_b_constant{ true }; - bool _is_c_constant{ true }; + std::vector _indirect_pad{}; + arm_gemm::ConvolutionParameters _cp{}; + experimental::MemoryRequirements _aux_mem{ Count }; + bool _B_pretranspose_required{ false }; + bool _is_b_constant{ true }; + bool _is_c_constant{ true }; }; template @@ -635,6 +635,72 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() { } +Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); + ARM_COMPUTE_UNUSED(c); + arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info); + Params p = extract_parameters(a, b, d, info); + const CPUInfo &ci = NEScheduler::get().cpu_info(); + unsigned int num_threads = NEScheduler::get().num_threads(); + + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fast_mode); + switch(a->data_type()) + { + case DataType::F32: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + "We could not find an optimized kernel for F32 input"); + break; +#ifdef __aarch64__ + case DataType::U8: + case DataType::QASYMM8: + if(d->data_type() == DataType::S32) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::gemm(args, {})), + "We could not find an optimized kernel for U8/QASYMM8 input and S32 output"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + "We could not find an optimized kernel for U8 input and U8 output"); + } + break; + case DataType::S8: + case DataType::QASYMM8_SIGNED: + if(d->data_type() == DataType::S32) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + "We could not find an optimized kernel for S8 input and S32 output"); + } + break; +#endif /* __aarch64__ */ +#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) + case DataType::BFLOAT16: + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + "We could not find an optimized kernel for BFLOAT16 input and F32 output"); + break; + } +#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + "We could not find an optimized kernel for BFLOAT16 input and F32 output"); + break; +#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ + default: + ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); + break; + } + + return Status{}; +} + Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) { ARM_COMPUTE_UNUSED(c, info); @@ -663,7 +729,7 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - return Status{}; + return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info); } bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index a50f3634c2..74359eee72 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -68,11 +68,11 @@ public: class IFallback { public: - virtual void run(ITensorPack &tensors) = 0; - virtual void prepare(ITensorPack &tensors) = 0; - virtual experimental::MemoryRequirements workspace() const = 0; - virtual bool is_configured() const = 0; - virtual ~IFallback() = default; + virtual void run(ITensorPack &tensors) = 0; + virtual void prepare(ITensorPack &tensors) = 0; + virtual experimental::MemoryRequirements workspace() const = 0; + virtual bool is_configured() const = 0; + virtual ~IFallback() = default; }; public: @@ -97,6 +97,18 @@ public: * @return a status. */ static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); + + /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. + * + * @param[in] a Input tensor info (Matrix A) + * @param[in] b Input tensor info (Matrix B) + * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations + * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] info GEMM meta-data + * + * @return a status. + */ + static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); /** Checks if activation is supported by the gemm assembly dispatcher * * @param[in] activation Activation to check @@ -111,8 +123,8 @@ public: bool is_configured() const; // Inherited methods overridden: - void prepare(ITensorPack &tensors) override; - void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: -- cgit v1.2.1