diff options
author | Pablo Marquez Tello <pablo.tello@arm.com> | 2024-05-14 07:54:19 +0100 |
---|---|---|
committer | Pablo Marquez Tello <pablo.tello@arm.com> | 2024-05-14 11:58:27 +0000 |
commit | 2217f1e60964fe586cae7ef996af7ef1c0bef2ab (patch) | |
tree | 593dad484fe922efeb655335d37b0066d2af4bcd | |
parent | 21fb2ad16a30a5ff29929515abe28c14b2c6b5a1 (diff) | |
download | ComputeLibrary-2217f1e60964fe586cae7ef996af7ef1c0bef2ab.tar.gz |
Refactor arm_gemm to enable FP16 in all multi_isa builds
* Resolves MLCE-1285
Change-Id: I22a37972aefe1c0f04accbc798baa18358ed8959
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11552
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r-- | filelist.json | 15 |
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp | 6 |
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp | 6 |
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp | 38 |
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/transform.cpp | 8 |
-rw-r--r-- | src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 10 |
6 files changed, 43 insertions, 40 deletions
diff --git a/filelist.json b/filelist.json index 77656bcab8..5246f27f68 100644 --- a/filelist.json +++ b/filelist.json @@ -1593,7 +1593,6 @@ "neon": { "common": [ "src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp", - "src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp", "src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp", "src/core/NEON/kernels/arm_gemm/gemm_bf16bf16.cpp", "src/core/NEON/kernels/arm_gemm/gemm_int16.cpp", @@ -1605,7 +1604,6 @@ "src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp", "src/core/NEON/kernels/arm_gemm/interleave-8way.cpp", "src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp", - "src/core/NEON/kernels/arm_gemm/mergeresults-fp16.cpp", "src/core/NEON/kernels/arm_gemm/mergeresults.cpp", "src/core/NEON/kernels/arm_gemm/misc.cpp", "src/core/NEON/kernels/arm_gemm/quantized.cpp", @@ -1622,13 +1620,8 @@ "src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_8x12/a55r1.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_8x12/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_8x12/x1.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24/a55r1.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24/generic.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24/x1.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_bf16fp32_dot_6x16/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_bf16fp32_mmla_6x16/generic.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp16_mla_6x32/a55.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp16_mla_6x32/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24/a55.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_4x24/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_6x16/a55.cpp", @@ -1682,6 +1675,13 @@ "fp32":["src/cpu/kernels/gemm_matrix_mul/generic/neon/fp32.cpp", "src/cpu/kernels/gemm_matrix_add/generic/neon/fp32.cpp"], "fp16":["src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp", + 
"src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24/a55r1.cpp", + "src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp", + "src/core/NEON/kernels/arm_gemm/mergeresults-fp16.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp16_mla_6x32/a55.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp16_mla_6x32/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24/x1.cpp", "src/cpu/kernels/gemm_matrix_add/generic/neon/fp16.cpp"], "estate32": [ "src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp", @@ -1690,6 +1690,7 @@ ], "estate64": [ "src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp" + ], "fixed_format_kernels": [ "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp", diff --git a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp index 59591935cd..7c09608e3e 100644 --- a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp +++ b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022 Arm Limited. + * Copyright (c) 2020-2022, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -330,11 +330,11 @@ template void Interleave<8, 2, VLType::None>(float *, const float *, size_t, uns #endif // ARM_COMPUTE_ENABLE_SVE && ARM_COMPUTE_ENABLE_SVEF32MM /* FP16 */ -#if defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16) template void IndirectInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void Interleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); -#endif // FP16_KERNELS ar __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // FP16_KERNELS ar ARM_COMPUTE_ENABLE_FP16 template void IndirectInterleave<8, 1, VLType::None>(float *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); template void ConvolutionInterleave<8, 1, VLType::None>(float *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t); diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp index 586d6a64a4..d9668aae02 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2021, 2024 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) #include "../performance_parameters.hpp" #include "../std_transforms_fixed.hpp" @@ -89,4 +89,4 @@ public: } // namespace arm_gemm -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp index a81d4504ae..ba47e0aa54 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2020, 2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,7 @@ */ #pragma once -#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)) template<> void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const __fp16 *bias, Activation act, bool append) @@ -86,7 +86,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -140,7 +140,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch 
armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -217,7 +217,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -317,7 +317,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -439,7 +439,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -584,7 +584,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -752,7 +752,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -944,7 +944,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1150,7 +1150,7 @@ void 
MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1204,7 +1204,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1278,7 +1278,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1372,7 +1372,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1485,7 +1485,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1618,7 +1618,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1771,7 +1771,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { 
/* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -1945,7 +1945,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } else { /* Optimized routine to copy an entire block */ __asm __volatile ( -#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifndef ARM_COMPUTE_ENABLE_FP16 ".arch armv8.2-a+fp16\n" #endif "dup v0.8h, %[maxval].h[0]\n" @@ -2112,4 +2112,4 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, } } -#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16) diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp index 45e4f0e1de..06d9e2416c 100644 --- a/src/core/NEON/kernels/arm_gemm/transform.cpp +++ b/src/core/NEON/kernels/arm_gemm/transform.cpp @@ -129,17 +129,17 @@ void Transform( // We don't have assembler transforms for AArch32, generate templated ones here. 
#ifdef __arm__ template void Transform<8, 1, true, VLType::None>(float *, const float *, int, int, int, int, int); -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<8, 1, true, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 #endif // AArch32 -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#if defined(ARM_COMPUTE_ENABLE_FP16) template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int); -#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +#endif // defined(ARM_COMPUTE_ENABLE_FP16) #ifdef ARM_COMPUTE_ENABLE_BF16 template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int); #endif // ARM_COMPUTE_ENABLE_BF16 diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index 7d85885654..cc21ccbaa1 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -945,6 +945,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected } break; #endif /* __aarch64__ */ + #if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: { @@ -963,13 +964,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected break; } #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(ARM_COMPUTE_ENABLE_FP16) case DataType::F16: ARM_COMPUTE_RETURN_ERROR_ON_MSG( !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F16 input and F16 
output"); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +#endif /* ARM_COMPUTE_ENABLE_FP16 */ default: ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, "Usupported type. Could not find a kernel"); break; @@ -1102,11 +1104,11 @@ void CpuGemmAssemblyDispatch::configure( } break; #endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#ifdef ARM_COMPUTE_ENABLE_FP16 case DataType::F16: create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info); break; -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ +#endif /* ARM_COMPUTE_ENABLE_FP16 */ default: break; } |