Diffstat (limited to 'src/core/NEON/kernels/arm_gemm')
 src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp                  |  2
 src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp                  | 12
 src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp                  | 24
 src/core/NEON/kernels/arm_gemm/gemm_int8.cpp                  |  2
 src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp           | 45
 src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp                 |  2
 src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp                |  2
 src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp                |  6
 src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp        |  6
 src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp     |  6
 src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp | 38
 src/core/NEON/kernels/arm_gemm/transform.cpp                  |  8
 12 files changed, 92 insertions(+), 61 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 5c08e6137d..0ddca04846 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -86,7 +86,7 @@ static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] =
     "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(args); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 3b444ae333..c7adf8e4ac 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -69,19 +69,19 @@ static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
-    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL",
+    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
-    [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+    [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
-    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
+    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
-    [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
+                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
+    [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
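The predicate change above, adding "args._Nsize >= 8*VL", recurs for every SME2 "1VLx4VL" entry in this patch (bf16, fp16, fp32, int8, qint8, quint8): the wide tile (1 vector of rows by 4 vectors of columns) is now also selected whenever the output is at least 8 vectors wide, not only for the M sizes that fit the tall tile poorly. The fp16 hunks additionally swap the 1VLx4VL and 4VLx1VL entries so the wide variant is considered first. A minimal standalone sketch of the updated heuristic follows; the free-function form and name are illustrative only (in the library this is an inline lambda over GemmArgs, with VL from sme::get_vector_length<T>() for the accumulator type T):

// Sketch of the updated SME2 tile-shape heuristic from this patch.
bool prefer_1VLx4VL(unsigned int M, unsigned int N, unsigned int VL) {
    // New in this patch: a wide output (N >= 8 vectors) favours the
    // 1VLx4VL tile regardless of M.
    if (N >= 8 * VL) {
        return true;
    }
    // Pre-existing rule, kept unchanged: M fits in a single vector, or
    // lies in the (2*VL, 3*VL] band.
    return M <= VL || (2 * VL < M && M <= 3 * VL);
}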
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 290fe87230..0c1d3a387b 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -124,14 +124,14 @@ GemmImplementation<float, float>::with_estimate(
 {
     GemmMethod::GEMM_HYBRID,
     "sme2_gemv_fp32bf16fp32_dot_16VL",
-    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
+    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; },
     nullptr,
     [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); }
 },
 {
     GemmMethod::GEMM_HYBRID,
     "sme2_gemv_fp32_mla_16VL",
-    [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
+    [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input && !args._accumulate; },
     nullptr,
     [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); }
 },
@@ -139,25 +139,25 @@ GemmImplementation<float, float>::with_estimate(
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
-    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); }
 },
 #endif // ARM_COMPUTE_ENABLE_BF16
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
-    [](const GemmArgs &args) { return args._ci->has_sme2(); },
+    [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); }
 },
 #ifdef ARM_COMPUTE_ENABLE_BF16
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
-    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                                return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); }
@@ -166,7 +166,7 @@ GemmImplementation<float, float>::with_estimate(
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL",
-    [](const GemmArgs &args) { return args._ci->has_sme2(); },
+    [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                                return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); }
@@ -175,7 +175,7 @@ GemmImplementation<float, float>::with_estimate(
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
-    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
+    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
     nullptr,
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); }
 },
@@ -183,7 +183,7 @@ GemmImplementation<float, float>::with_estimate(
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL",
-    [](const GemmArgs &args) { return args._ci->has_sme2(); },
+    [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
     nullptr,
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); }
 },
@@ -199,14 +199,14 @@ GemmImplementation<float, float>::with_estimate(
 GemmImplementation<float, float>::with_estimate(
     GemmMethod::GEMM_HYBRID,
     "sve_hybrid_fp32bf16fp32_mmla_6x4VL",
-    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
+    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
     [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); },
     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); }
 ),
 GemmImplementation<float, float>::with_estimate(
     GemmMethod::GEMM_HYBRID,
     "sve_hybrid_fp32bf16fp32_mmla_4x6VL",
-    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
+    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
     [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); },
     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); }
 ),
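The fp32 table also gains "!args._accumulate" guards on the SME2 GEMV and no-merge MOPA entries, evidently because these kernels store their result straight to the output rather than adding to it, so they cannot serve accumulate-mode calls. The last two hunks tighten the feature test for the SVE bf16 hybrid kernels from has_bf16() to has_svebf16(), matching the SVE bf16 instructions those kernels use. A tiny sketch of the distinction the new guard encodes (names are illustrative, not the library's API):

// Illustrative only: a no-merge kernel writes C = A*B directly, with no
// read-modify-write of C, so it cannot honour accumulate mode (C += A*B).
enum class OutputMode { Overwrite, Accumulate };

bool nomerge_mopa_usable(bool has_sme2, OutputMode mode) {
    // Mirrors the guard added in this patch: has_sme2() && !_accumulate.
    return has_sme2 && mode == OutputMode::Overwrite;
}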
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 0dc0d55b27..fedda3a47a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -63,7 +63,7 @@ static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = {
     "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(args); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index ae344f09b5..897ec9d05f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -190,10 +190,19 @@ void kernel_and_merge<false, false, Requantize32>::run(
     auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
 #endif

+    // Offset C pointer in a similar way to non-quantized case above.
+    Tri *offset_c_ptr;
+
+    if (c_ptr == nullptr) {
+        offset_c_ptr = nullptr;
+    } else {
+        offset_c_ptr = c_ptr + m_0 * ldc + n_0;
+    }
+
     strat.kernel(// A and B pointers are just the packed panels.
                  a_ptr, b_panel,
                  // Provide relevant part of output array and row stride.
-                 c_ptr + m_0 * ldc + n_0, ldc,
+                 offset_c_ptr, ldc,
                  // M, N, K sizes
                  m_max-m_0, n_max - n_0, kern_k,
                  // Bias, activation, accumulation. Need to offset the bias as needed.
@@ -663,15 +672,27 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
             return roundup(args._cfg->inner_block_size, strategy::k_unroll());
         }

-        // K blocking not supported if we are requantizing.
-        if (std::is_same<OutputStage, Requantize32>::value) {
+        // K blocking not supported if we are requantizing with the merging
+        // kernels.
+        if (std::is_same<OutputStage, Requantize32>::value && MergeStep) {
             return get_ktotal(args);
         }

+        const unsigned int L1_size = args._ci->get_L1_cache_size();
+
         // Special blocking for SME
         if (is_sme<strategy>::value) {
-            // Don't bother to block below this size threshold, experimentally determined to be 320 for FP32
-            unsigned int scaling_threshold = 1280 / sizeof(Toi);
+            // Target 512 bytes for 64kB L1, or 1024 bytes for 128kB L1.
+            unsigned int target_bytes_per_block = L1_size / 128;
+
+            // Default cache size in gemm-linux is 32kB though - so make
+            // sure minimum is 512
+            if (target_bytes_per_block < 512) {
+                target_bytes_per_block = 512;
+            }
+
+            // Don't bother to block below this size threshold (1.25X target size)
+            unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi);

             if (get_ktotal(args) <= scaling_threshold) {
                 return get_ktotal(args);
@@ -679,7 +700,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {

             // Once we are blocking, this (lower) threshold determines when we should use more blocks
             // NOTE: Could be that some factor-based solution would work better here.
-            unsigned int max_block_size = 1024 / sizeof(Toi);
+            unsigned int max_block_size = target_bytes_per_block / sizeof(Toi);

             unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size);

@@ -688,7 +709,6 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
             return k_block;
         }

-        const unsigned int L1_size = args._ci->get_L1_cache_size();
         unsigned int k_block;

         // k_block: Find out how much of the larger array can be loaded into half the cache.
@@ -723,6 +743,17 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
             return roundup(args._cfg->outer_block_size, strategy::out_width());
         }

+        // Special blocking for SME
+        if (is_sme<strategy>::value) {
+            // If total width is less than 4x kernel width, return the entire width.
+            if (args._Nsize < strategy::out_width()*4) {
+                return roundup(args._Nsize, strategy::out_width());
+            }
+
+            // Otherwise block to single kernel width.
+            return strategy::out_width();
+        }
+
         unsigned int x_block;
         const unsigned int L2_size = args._ci->get_L2_cache_size();
         const unsigned int k_block = get_k_block_size(args);
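The gemm_interleaved.hpp changes do three things: the requantizing (Requantize32) case now forbids K-blocking only when a merge step is present; the C pointer is offset only when non-null, avoiding pointer arithmetic on nullptr; and the SME K-blocking constants (previously hard-coded 1280/1024 bytes) are derived from the L1 cache size, with SME also getting its own N-blocking rule. A self-contained sketch of the new blocking arithmetic follows; the helper definitions restate the library's iceildiv/roundup, and the final even-split step is simplified since the hunk does not show any rounding the library applies afterwards:

#include <algorithm>

// Integer helpers, restated so the sketch is self-contained.
static unsigned int iceildiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }
static unsigned int roundup(unsigned int a, unsigned int b)  { return iceildiv(a, b) * b; }

// SME K-blocking after this patch: target L1_size/128 bytes per block
// (512 bytes for a 64kB L1, 1024 for 128kB), floored at 512 bytes because
// the default modelled cache size is 32kB. Only block once K exceeds
// 1.25x the target.
unsigned int sme_k_block(unsigned int ktotal, unsigned int L1_size, unsigned int elem_size) {
    unsigned int target_bytes_per_block = std::max(L1_size / 128, 512u);

    unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / elem_size;
    if (ktotal <= scaling_threshold) {
        return ktotal;  // small K: process in a single pass
    }

    // Split K into the fewest blocks that respect the target size, then
    // size the blocks evenly.
    unsigned int max_block_size = target_bytes_per_block / elem_size;
    unsigned int num_k_blocks   = iceildiv(ktotal, max_block_size);
    return iceildiv(ktotal, num_k_blocks);
}

// SME N-blocking after this patch: below four kernel widths, take the
// whole width (rounded up); otherwise block to a single kernel width.
unsigned int sme_x_block(unsigned int Nsize, unsigned int out_width) {
    if (Nsize < out_width * 4) {
        return roundup(Nsize, out_width);
    }
    return out_width;
}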
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index d1c4e49edb..321c97262f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -82,7 +82,7 @@ static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods
     "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL",
     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
     [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
-                                                     return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                                                     return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_1VLx4VL, int8_t, int8_t>(args, qp); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index b85b1c4fcf..93eecf991e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -78,7 +78,7 @@ static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_meth
     "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL",
     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
     [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>();
-                                                     return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                                                     return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
index 782399df8c..38d9b763f6 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_s8fp32.cpp
@@ -55,7 +55,7 @@ static const GemmImplementation<int8_t, float, DequantizeFloat> gemm_s8fp32_meth
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL.hpp",
-    [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+    [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
                                                         return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_1VLx4VL, int8_t, float>(args, dq); }
@@ -63,7 +63,7 @@ static const GemmImplementation<int8_t, float, DequantizeFloat> gemm_s8fp32_meth
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_s8qfp32_mopa_4Vx1VL.hpp",
-    [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+    [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args, const DequantizeFloat &) { const auto VL = sme::get_vector_length<float>();
                                                         return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
     [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_4VLx1VL, int8_t, float>(args, dq); }
@@ -71,7 +71,7 @@ static const GemmImplementation<int8_t, float, DequantizeFloat> gemm_s8fp32_meth
 {
     GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_s8qfp32_mopa_2Vx2VL.hpp",
-    [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2(); },
+    [](const GemmArgs &args, const DequantizeFloat &) { return args._ci->has_sme2() && !args._accumulate; },
     nullptr,
     [](const GemmArgs &args, const DequantizeFloat &dq) { return new GemmInterleavedNoMergeDequantized<cls_sme2_interleaved_nomerge_s8qfp32_mopa_2VLx2VL, int8_t, float>(args, dq); }
 },
diff --git a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp
index 59591935cd..7c09608e3e 100644
--- a/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp
+++ b/src/core/NEON/kernels/arm_gemm/interleave_indirect.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022 Arm Limited.
+ * Copyright (c) 2020-2022, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -330,11 +330,11 @@ template void Interleave<8, 2, VLType::None>(float *, const float *, size_t, uns
 #endif // ARM_COMPUTE_ENABLE_SVE && ARM_COMPUTE_ENABLE_SVEF32MM

 /* FP16 */
-#if defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16)
 template void IndirectInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t);
 template void ConvolutionInterleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t);
 template void Interleave<8, 1, VLType::None>(__fp16 *, const __fp16 *, size_t, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t);
-#endif // FP16_KERNELS ar __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#endif // FP16_KERNELS ar ARM_COMPUTE_ENABLE_FP16

 template void IndirectInterleave<8, 1, VLType::None>(float *, const __fp16 * const * const *, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t);
 template void ConvolutionInterleave<8, 1, VLType::None>(float *, const __fp16 *, size_t, const convolver<__fp16> &, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, bool, int32_t);
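The interleave_indirect.cpp hunks above, and the kernel, merge and transform files below, make one mechanical change: FP16 code paths are now gated on the build-system macro ARM_COMPUTE_ENABLE_FP16 instead of the compiler-defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC (the FP16_KERNELS escape hatch is kept), and copyright years are extended to 2024. Inside the merge routine's assembly blocks, the ".arch armv8.2-a+fp16" directive is now emitted whenever ARM_COMPUTE_ENABLE_FP16 is absent, so the assembler accepts FP16 instructions even if the compile baseline lacks them. A minimal sketch of the guard pattern as it stands after the patch; the function and its asm body are illustrative, not taken from the library:

#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16))

// Illustrative FP16 routine using the post-patch guard pattern.
void fp16_double_inplace(__fp16 *data) {
    __asm __volatile (
#ifndef ARM_COMPUTE_ENABLE_FP16
        // Without the build-wide FP16 flag, the compile baseline may
        // predate FP16 arithmetic, so tell the assembler to accept it.
        ".arch armv8.2-a+fp16\n"
#endif
        "ldr  q0, [%[ptr]]\n"
        "fadd v0.8h, v0.8h, v0.8h\n"   // FP16 vector add: needs +fp16
        "str  q0, [%[ptr]]\n"
        :
        : [ptr] "r" (data)
        : "v0", "memory"
    );
}

#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16)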
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp
index 586d6a64a4..d9668aae02 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -23,7 +23,7 @@
  */
 #pragma once

-#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
+#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16))

 #include "../performance_parameters.hpp"
 #include "../std_transforms_fixed.hpp"
@@ -89,4 +89,4 @@ public:

 } // namespace arm_gemm

-#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16)
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp
index a81d4504ae..ba47e0aa54 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_fp16_24x8.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2020, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -23,7 +23,7 @@
  */
 #pragma once

-#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
+#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(ARM_COMPUTE_ENABLE_FP16))

 template<>
 void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const __fp16 *bias, Activation act, bool append)
@@ -86,7 +86,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -140,7 +140,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -217,7 +217,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -317,7 +317,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -439,7 +439,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -584,7 +584,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -752,7 +752,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -944,7 +944,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1150,7 +1150,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1204,7 +1204,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1278,7 +1278,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1372,7 +1372,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1485,7 +1485,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1618,7 +1618,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1771,7 +1771,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -1945,7 +1945,7 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
             } else {
                 /* Optimized routine to copy an entire block */
                 __asm __volatile (
-#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#ifndef ARM_COMPUTE_ENABLE_FP16
                     ".arch armv8.2-a+fp16\n"
 #endif
                     "dup v0.8h, %[maxval].h[0]\n"
@@ -2112,4 +2112,4 @@ void MergeResults<24, 8, false>(__fp16 *out, const __fp16 *in, const int ldout,
     }
 }

-#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif // __aarch64__ && (FP16_KERNELS || ARM_COMPUTE_ENABLE_FP16)
diff --git a/src/core/NEON/kernels/arm_gemm/transform.cpp b/src/core/NEON/kernels/arm_gemm/transform.cpp
index 45e4f0e1de..06d9e2416c 100644
--- a/src/core/NEON/kernels/arm_gemm/transform.cpp
+++ b/src/core/NEON/kernels/arm_gemm/transform.cpp
@@ -129,17 +129,17 @@ void Transform(

 // We don't have assembler transforms for AArch32, generate templated ones here.
 #ifdef __arm__
 template void Transform<8, 1, true, VLType::None>(float *, const float *, int, int, int, int, int);
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(ARM_COMPUTE_ENABLE_FP16)
 template void Transform<8, 1, true, VLType::None>(float *, const __fp16 *, int, int, int, int, int);
-#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif // defined(ARM_COMPUTE_ENABLE_FP16)
 #ifdef ARM_COMPUTE_ENABLE_BF16
 template void Transform<8, 1, true, VLType::None>(float *, const bfloat16 *, int, int, int, int, int);
 #endif // ARM_COMPUTE_ENABLE_BF16
 #endif // AArch32

-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#if defined(ARM_COMPUTE_ENABLE_FP16)
 template void Transform<12, 1, false, VLType::None>(float *, const __fp16 *, int, int, int, int, int);
-#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+#endif // defined(ARM_COMPUTE_ENABLE_FP16)
 #ifdef ARM_COMPUTE_ENABLE_BF16
 template void Transform<12, 1, false, VLType::None>(float *, const bfloat16 *, int, int, int, int, int);
 #endif // ARM_COMPUTE_ENABLE_BF16