diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 92 |
1 files changed, 34 insertions, 58 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index 9194bdd4d4..1a90e96140 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -38,75 +38,51 @@ namespace arm_gemm { -#ifdef __ARM_FEATURE_SVE -class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> { -public: - - UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override { - return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<interleaved_fp16_mla_3VLx8, __fp16, __fp16>(args)); - } - - GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { } -}; - -#elif defined(__aarch64__) - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) -class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> { -public: +static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = { +#if defined(__ARM_FEATURE_SVE) +{ + GemmMethod::GEMM_INTERLEAVED, + "interleaved_fp16_mla_3VLx8", + [](const GemmArgs<__fp16> &args) { return (args._Ksize > 4); }, + [](const GemmArgs<__fp16> &args) { return true; }, + [](const GemmArgs<__fp16> &args) { return new GemmInterleaved<interleaved_fp16_mla_3VLx8, __fp16, __fp16>(args); } +}, +#endif +#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)) +{ + GemmMethod::GEMM_INTERLEAVED, + "hgemm_24x8", + [](const GemmArgs<__fp16> &args) { #ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - bool is_supported(const GemmArgs<__fp16> &args) override { return args._ci->has_fp16(); - } -#endif - - UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override { - return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args)); - } - - GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { } -}; -#endif - -#endif // __aarch64__ - -class GemmImpl_gemm_fp16_interleaved : public GemmImplementation<__fp16, __fp16> { -public: - UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override { -#ifdef __aarch64__ - return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(args)); -#elif defined(__arm__) - return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args)); #else -# error Unknown Architecture + return true; #endif - } - - GemmImpl_gemm_fp16_interleaved() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED) { } -}; - -#if defined(__aarch64__) && (defined(__ARM_FEATURE_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) || defined(__ARM_FEATURE_SVE)) -static GemmImpl_gemm_fp16_interleaved_fp16 gemm_fp16_interleaved_fp16_impl{}; -#endif -static GemmImpl_gemm_fp16_interleaved gemm_fp16_interleaved_impl{}; - -static std::vector<GemmImplementation<__fp16, __fp16> *> gemm_fp16_methods = { -#if defined(__aarch64__) && (defined(__ARM_FEATURE_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) || defined(__ARM_FEATURE_SVE)) - &gemm_fp16_interleaved_fp16_impl, + }, + [](const GemmArgs<__fp16> &args) { return true; }, + [](const GemmArgs<__fp16> &args) { return new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args); } +}, #endif - &gemm_fp16_interleaved_impl +{ + GemmMethod::DEFAULT, + "", + nullptr, + nullptr, + nullptr, +} }; template<> -std::vector<GemmImplementation<__fp16, __fp16> *> &gemm_implementation_list<__fp16, __fp16>() { +const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp16>() { return gemm_fp16_methods; } /* Explicitly instantiate the external functions for these types. */ -template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(GemmArgs<__fp16> &args, GemmConfig *cfg); -template GemmMethod get_gemm_method<__fp16, __fp16>(GemmArgs<__fp16> &args); -template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, GemmArgs<__fp16> &args); +template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(const GemmArgs<__fp16> &args); +template KernelDescription get_gemm_method<__fp16, __fp16>(const GemmArgs<__fp16> &args); +template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, const GemmArgs<__fp16> &args); +template std::vector<std::string> get_compatible_kernels<__fp16, __fp16> (const GemmArgs<__fp16> &args); } // namespace arm_gemm -#endif // __ARM_FP16_ARGS +#endif // __ARM_FP16_ARGS
\ No newline at end of file |