diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-10-14 19:03:09 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-10-23 12:08:12 +0000 |
commit | 48b3ef89de5f21a0169d8416e3d54081f82c7bf8 (patch) | |
tree | f857d733ccf446c704823dc7ac796a96eb55095e /src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | |
parent | 1dce3101ef8d77c8cf0af7dfd4af6595a0136b91 (diff) | |
download | ComputeLibrary-48b3ef89de5f21a0169d8416e3d54081f82c7bf8.tar.gz |
COMPMID-2577: Fuse bias addition and activation in gemm assembly kernels
Change-Id: I7f52112d2d05b1ea3d3f3d4b19b8eafab05d6c44
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2141
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 90 |
1 files changed, 41 insertions, 49 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index dedcdb7655..cf91ee0652 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -33,17 +33,16 @@ #include "kernels/a32_sgemm_8x6.hpp" #include "kernels/a64_hybrid_fp32_mla_16x4.hpp" +#include "kernels/a64_native_fp32_mla_16x4.hpp" +#include "kernels/a64_smallK_hybrid_fp32_mla_4x6.hpp" +#include "kernels/a64_smallK_hybrid_fp32_mla_4x8.hpp" #include "kernels/a64_sgemm_12x8.hpp" -#include "kernels/a64_sgemm_native_16x4.hpp" -#include "kernels/a64_sgemm_nativeA_pretransposeB_16x4.hpp" #include "kernels/a64_sgemv_pretransposed.hpp" #include "kernels/a64_sgemv_trans.hpp" #include "kernels/sve_hybrid_fp32_mla_4VLx4.hpp" #include "kernels/sve_interleaved_fp32_mla_3VLx8.hpp" #include "kernels/sve_native_fp32_mla_4VLx4.hpp" -#include "kernels/sve_smallK_fp32_mla_1VLx4.hpp" -#include "kernels/sve_smallK_hybrid_fp32_mla_1VLx4.hpp" namespace arm_gemm { @@ -52,88 +51,81 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = { GemmMethod::GEMV_BATCHED, "gemv_batched", - [](const GemmArgs<float> &args) { return (args._Msize==1) && (args._nbatches>1); }, + [](const GemmArgs &args) { return (args._Msize==1) && (args._nbatches>1); }, nullptr, - [](const GemmArgs<float> &args) { return new GemvBatched<float, float>(args); } + [](const GemmArgs &args) { return new GemvBatched<float, float>(args); } }, #ifdef __aarch64__ { GemmMethod::GEMV_PRETRANSPOSED, "sgemv_pretransposed", - [](const GemmArgs<float> &args) { return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1); }, + [](const GemmArgs &args) { return (args._Msize==1 && args._pretransposed_hint && args._nbatches==1); }, nullptr, - [](const GemmArgs<float> &args) { return new GemvPretransposed<sgemv_pretransposed, float, float>(args); } + [](const GemmArgs &args) { return new GemvPretransposed<sgemv_pretransposed, float, float>(args); } }, { GemmMethod::GEMV_NATIVE_TRANSPOSED, "sgemv_trans", - [](const GemmArgs<float> &args) { return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1); }, + [](const GemmArgs &args) { return (args._Msize==1 && !args._trA && !args._trB && args._nbatches==1); }, nullptr, - [](const GemmArgs<float> &args) { return new GemvNativeTransposed<sgemv_trans, float, float>(args); } + [](const GemmArgs &args) { return new GemvNativeTransposed<sgemv_trans, float, float>(args); } }, #ifdef __ARM_FEATURE_SVE -// SVE smallk / native / hybrid methods -{ - GemmMethod::GEMM_HYBRID, - "smallK_hybrid_fp32_mla_1VLx4", - [](const GemmArgs<float> &args) { return (args._Ksize <= 24) && !args._trA && args._alpha==1.0f && args._pretransposed_hint; }, - nullptr, - [](const GemmArgs<float> &args) { return new GemmHybrid<smallK_hybrid_fp32_mla_1VLx4, float, float>(args); } -}, +// SVE native / hybrid methods { GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_4VLx4", - [](const GemmArgs<float> &args) { return (args._Ksize >= 4) && (args._alpha == 1.0f) && !args._trA && args._pretransposed_hint; }, - [](const GemmArgs<float> &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs<float> &args) { return new GemmHybrid<hybrid_fp32_mla_4VLx4, float, float>(args); } -}, -{ - GemmMethod::GEMM_NATIVE, - "smallK_fp32_mla_1VLx4", - [](const GemmArgs<float> &args) { return (args._Ksize <= 24) && !args._trA && !args._trB && args._alpha==1.0f; }, - nullptr, - [](const GemmArgs<float> &args) { return new GemmNative<smallK_fp32_mla_1VLx4, float, float>(args); } + [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4VLx4, float, float>(args); } }, { GemmMethod::GEMM_NATIVE, "native_fp32_mla_4VLx4", - [](const GemmArgs<float> &args) { return (args._Ksize>4 && args._alpha==1.0f && !args._trA && !args._trB); }, - [](const GemmArgs<float> &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs<float> &args) { return new GemmNative<native_fp32_mla_4VLx4, float, float>(args); } + [](const GemmArgs &args) { return (args._Ksize>4 && !args._trA && !args._trB); }, + [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args) { return new GemmNative<native_fp32_mla_4VLx4, float, float>(args); } }, #endif // __ARM_FEATURE_SVE // NEON native / hybrid methods { GemmMethod::GEMM_HYBRID, - "sgemm_nativeA_pretransposeB_16x4", - [](const GemmArgs<float> &args) { return (args._Ksize >= 4) && (args._alpha == 1.0f) && !args._trA && args._pretransposed_hint; }, - [](const GemmArgs<float> &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs<float> &args) { return new GemmHybrid<sgemm_nativeA_pretransposeB_16x4, float, float>(args); } + "smallK_hybrid_fp32_mla_4x8", + [](const GemmArgs &args) { return (args._Ksize <= 8) && (args._Nsize % 4)==0 && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args) { return new GemmHybrid<smallK_hybrid_fp32_mla_4x8, float, float>(args); } +}, +{ + GemmMethod::GEMM_HYBRID, + "smallK_hybrid_fp32_mla_4x6", + [](const GemmArgs &args) { return (args._Ksize > 8) && (args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._trA && args._pretransposed_hint; }, + nullptr, + [](const GemmArgs &args) { return new GemmHybrid<smallK_hybrid_fp32_mla_4x6, float, float>(args); } }, { GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_16x4", - [](const GemmArgs<float> &args) { return (args._Ksize >= 4) && (args._alpha == 1.0f) && !args._trA && args._pretransposed_hint; }, - [](const GemmArgs<float> &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs<float> &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); } + [](const GemmArgs &args) { return (args._Ksize >= 4) && !args._trA && args._pretransposed_hint; }, + [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); } }, { GemmMethod::GEMM_NATIVE, - "sgemm_native_16x4", - [](const GemmArgs<float> &args) { return (args._Ksize>4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB); }, - [](const GemmArgs<float> &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, - [](const GemmArgs<float> &args) { return new GemmNative<sgemm_native_16x4, float, float>(args); } + "native_fp32_mla_16x4", + [](const GemmArgs &args) { return (args._Ksize>4 && (args._Nsize % 16)==0 && !args._trA && !args._trB); }, + [](const GemmArgs &args) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); }, + [](const GemmArgs &args) { return new GemmNative<native_fp32_mla_16x4, float, float>(args); } }, #ifdef __ARM_FEATURE_SVE { GemmMethod::GEMM_INTERLEAVED, "interleaved_fp32_mla_3VLx8", - [](const GemmArgs<float> &args) { return (args._Ksize>4); }, + [](const GemmArgs &args) { return (args._Ksize>4); }, nullptr, - [](const GemmArgs<float> &args) { return new GemmInterleaved<interleaved_fp32_mla_3VLx8, float, float>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<interleaved_fp32_mla_3VLx8, float, float>(args); } }, #endif // __ARM_FEATURE_SVE { @@ -141,7 +133,7 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = "sgemm_12x8", nullptr, nullptr, - [](const GemmArgs<float> &args) { return new GemmInterleaved<sgemm_12x8, float, float>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<sgemm_12x8, float, float>(args); } }, #endif // __aarch64__ @@ -151,7 +143,7 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = "sgemm_8x6", nullptr, nullptr, - [](const GemmArgs<float> &args) { return new GemmInterleaved<sgemm_8x6, float, float>(args); } + [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, float, float>(args); } }, #endif // __arm__ { @@ -170,8 +162,8 @@ const GemmImplementation<float, float> *gemm_implementation_list<float, float>() } /* Explicitly instantiate the external functions for these types. */ -template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs<float> &args, const Nothing &); -template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs<float> &args, const Nothing &); -template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs<float> &args, const Nothing &); +template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &); +template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &); +template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &); } // namespace arm_gemm |