diff options
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuSoftmaxKernel.cpp | 39 |
1 files changed, 20 insertions, 19 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp index cbf3773ddc..054adfa23c 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.cpp +++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -34,8 +34,7 @@ #include "src/core/helpers/WindowHelpers.h" #include "src/core/common/Registrars.h" -#include "src/cpu/kernels/softmax/impl/neon/list.h" -#include "src/cpu/kernels/softmax/impl/sve/list.h" +#include "src/cpu/kernels/softmax/list.h" namespace arm_compute { @@ -74,12 +73,12 @@ static const SoftmaxLogits1DKernel available_logits_1d_kernels[] = { "sve_fp32_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32) && data.ci.has_sve(); }, - REGISTER_FP32_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float>) + REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_softmax) }, { "sve_fp16_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16) && data.ci.has_sve(); }, - REGISTER_FP16_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float16_t>) + REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_softmax) }, #endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ @@ -87,13 +86,13 @@ static const SoftmaxLogits1DKernel available_logits_1d_kernels[] = { "neon_fp32_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); }, - REGISTER_FP32_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float>) + REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_softmax) }, #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_fp16_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); }, - REGISTER_FP16_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float16_t>) + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_softmax) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ #endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ @@ -102,24 +101,26 @@ static const SoftmaxLogits1DKernel available_logits_1d_kernels[] = { "sve2_qu8_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8) && data.ci.has_sve2(); }, - REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_t>) + REGISTER_QASYMM8_SVE2(arm_compute::cpu::sve2_qasymm8_softmax) }, { "sve2_qs8_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve2(); }, - REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_signed_t>) + REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::sve2_qasymm8_signed_softmax) }, #endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ +#if defined(ARM_COMPUTE_ENABLE_NEON) { "neon_qu8_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); }, - REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_t>) + REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_softmax) }, { "neon_qs8_softmax_logits_1d", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, - REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_signed_t>) + REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_softmax) }, +#endif //defined(ARM_COMPUTE_ENABLE_NEON) }; static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] = @@ -128,46 +129,46 @@ static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] = { "sve_fp32_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32) && data.ci.has_sve(); }, - REGISTER_FP32_SVE(arm_compute::cpu::sve_logits_1d_max<float>) + REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_logits) }, { "sve_fp16_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16) && data.ci.has_sve(); }, - REGISTER_FP16_SVE(arm_compute::cpu::sve_logits_1d_max<float16_t>) + REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_logits) }, { "sve_qu8_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8) && data.ci.has_sve(); }, - REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_t>) + REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_qasymm8_logits) }, { "sve_qs8_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve(); }, - REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_signed_t>) + REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_qasymm8_signed_logits) }, #endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ #if defined(ARM_COMPUTE_ENABLE_NEON) { "neon_fp32_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); }, - REGISTER_FP32_NEON(arm_compute::cpu::neon_logits_1d_max<float>) + REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_logits) }, #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_fp16_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); }, - REGISTER_FP16_NEON(arm_compute::cpu::neon_logits_1d_max<float16_t>) + REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_logits) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_qu8_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); }, - REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_t>) + REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_logits) }, { "neon_qs8_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, - REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_signed_t>) + REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_singed_logits) }, #endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ }; |