diff options
author | Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com> | 2022-03-31 17:55:35 +0000 |
---|---|---|
committer | Francesco Petrogalli <francesco.petrogalli@arm.com> | 2022-04-06 13:41:41 +0000 |
commit | e33c55640b5e6e6af193a41a3376e2a01a321a35 (patch) | |
tree | 3a8282410d5f8e7f57276cbca9e2dfa6d14115d7 /src/core/NEON | |
parent | 4c17ba951b76e97102f101a88edbf012b722c732 (diff) | |
download | ComputeLibrary-e33c55640b5e6e6af193a41a3376e2a01a321a35.tar.gz |
[arm_gemm] Use static validate to find arm_gemm kernels.
The static method `CpuGemmAssemblyDispatch::validate` should look into
the list of the available kernels to make sure the one requested by
the user was found.
Formatting changes in the files touched by the patch have been
automatically inserted by the formatting script.
Resolves: ONCPUML-840
Change-Id: Icd650a30e142284a942c64f8a2b72441ee7b3f4e
Signed-off-by: Francesco.Petrogalli@arm.com <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7375
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp | 6 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_int16.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp | 1 | ||||
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp | 1 |
10 files changed, 15 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index f4af587898..dd72fb5901 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -144,6 +144,7 @@ const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, fl /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index a502262de3..42f4528066 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -108,6 +108,7 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1 /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 8b855ab07c..69a2803903 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -232,6 +232,7 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>() /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index 4d7f7983f8..cb3ff7aa29 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -236,6 +236,12 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons } template<typename Top, typename Tret, class OutputStage> +bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation<Top, Tret, OutputStage> *impl; + return find_implementation<Top, Tret, OutputStage>(args, os, impl); +} + +template<typename Top, typename Tret, class OutputStage> UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage &os) { const GemmImplementation<Top, Tret, OutputStage> *impl; diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp index d650116bf7..3915861112 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp @@ -56,6 +56,7 @@ const GemmImplementation<int16_t, int32_t> *gemm_implementation_list<int16_t, in /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index a1134559df..0c68e4dd99 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -159,6 +159,7 @@ const GemmImplementation<int8_t, int32_t> *gemm_implementation_list<int8_t, int3 /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index 1532816b74..6b813c7974 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -230,6 +230,7 @@ const GemmImplementation<int8_t, int8_t, Requantize32> *gemm_implementation_list } template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index a80766bad6..95139c2bf6 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -197,6 +197,7 @@ const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_li } template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp index d459df8126..20cee556f0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp @@ -56,6 +56,7 @@ const GemmImplementation<uint16_t, uint32_t> *gemm_implementation_list<uint16_t, /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index f2d46d5415..a2d2cc86f0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -157,6 +157,7 @@ const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, u /* Explicitly instantiate the external functions for these types. */ template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &); template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm |