diff options
Diffstat (limited to 'src/core/cpu/kernels/CpuSoftmaxKernel.cpp')
-rw-r--r-- | src/core/cpu/kernels/CpuSoftmaxKernel.cpp | 49 |
1 files changed, 24 insertions, 25 deletions
diff --git a/src/core/cpu/kernels/CpuSoftmaxKernel.cpp b/src/core/cpu/kernels/CpuSoftmaxKernel.cpp index 8ea186b16a..1e00e12050 100644 --- a/src/core/cpu/kernels/CpuSoftmaxKernel.cpp +++ b/src/core/cpu/kernels/CpuSoftmaxKernel.cpp @@ -47,7 +47,8 @@ namespace { struct SoftmaxSelectorData { - DataType dt; + DataType dt; + const CPUInfo &ci; }; using SoftmaxSelectorPtr = std::add_pointer<bool(const SoftmaxSelectorData &data)>::type; using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type; @@ -69,20 +70,20 @@ struct SoftmaxLogits1DMaxKernel static const SoftmaxLogits1DKernel available_logits_1d_kernels[] = { -#if defined(ENABLE_SVE) +#if defined(ARM_COMPUTE_ENABLE_SVE) { "sve_softmax_logits_1d_float", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32) && data.ci.has_sve(); }, REGISTER_FP32_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float>) }, { "sve_softmax_logits_1d_float", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16) && data.ci.has_sve(); }, REGISTER_FP16_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float16_t>) }, -#endif /* defined(ENABLE_SVE) */ +#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ -#if defined(ENABLE_NEON) +#if defined(ARM_COMPUTE_ENABLE_NEON) { "neon_softmax_logits_1d_float", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); }, @@ -95,20 +96,20 @@ static const SoftmaxLogits1DKernel available_logits_1d_kernels[] = REGISTER_FP16_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float16_t>) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ -#endif /* !defined(ENABLE_NEON) */ +#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ -#if defined(__ARM_FEATURE_SVE2) +#if defined(ARM_COMPUTE_ENABLE_SVE2) { "sve_softmax_logits_1d_quantized", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8) && data.ci.has_sve2(); }, REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_t>) }, { "sve_softmax_logits_1d_quantized", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve2(); }, REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_signed_t>) }, -#else /* !defined(__ARM_FEATURE_SVE2) */ +#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */ { "neon_softmax_logits_1d_quantized", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); }, @@ -119,35 +120,33 @@ static const SoftmaxLogits1DKernel available_logits_1d_kernels[] = [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_signed_t>) }, -#endif /* defined(__ARM_FEATURE_SVE2) */ - }; static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] = { -#if defined(ENABLE_SVE) +#if defined(ARM_COMPUTE_ENABLE_SVE) { "sve_logits_1d_max", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32) && data.ci.has_sve(); }, REGISTER_FP32_SVE(arm_compute::cpu::sve_logits_1d_max<float>) }, { "sve_logits_1d_max", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16) && data.ci.has_sve(); }, REGISTER_FP16_SVE(arm_compute::cpu::sve_logits_1d_max<float16_t>) }, { "sve_logits_1d_max", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8) && data.ci.has_sve(); }, REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_t>) }, { "sve_logits_1d_max", - [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, + [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve(); }, REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_signed_t>) }, -#endif /* defined(ENABLE_SVE) */ -#if defined(ENABLE_NEON) +#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ +#if defined(ARM_COMPUTE_ENABLE_NEON) { "neon_logits_1d_max", [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); }, @@ -170,14 +169,14 @@ static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] = [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_signed_t>) }, -#endif /* defined(ENABLE_NEON) */ +#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ }; const SoftmaxLogits1DKernel *get_implementation_logits(const SoftmaxSelectorData &data) { for(const auto &uk : available_logits_1d_kernels) { - if(uk.is_selected({ data.dt })) + if(uk.is_selected({ data.dt, CPUInfo::get() })) { return &uk; } @@ -189,7 +188,7 @@ const SoftmaxLogits1DMaxKernel *get_implementation_logits_max(const SoftmaxSelec { for(const auto &uk : available_logits_1d_max_kernels) { - if(uk.is_selected({ data.dt })) + if(uk.is_selected({ data.dt, CPUInfo::get() })) { return &uk; } @@ -253,7 +252,7 @@ void CpuLogits1DMaxKernel::run_op(ITensorPack &tensors, const Window &window, co const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - const auto *uk = get_implementation_logits_max(SoftmaxSelectorData{ src->info()->data_type() }); + const auto *uk = get_implementation_logits_max(SoftmaxSelectorData{ src->info()->data_type(), CPUInfo::get() }); uk->ukernel(src, dst, window); } @@ -364,7 +363,7 @@ void CpuLogits1DSoftmaxKernel<IS_LOG>::run_op(ITensorPack &tensors, const Window void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread); - const auto *uk = get_implementation_logits(SoftmaxSelectorData{ src->info()->data_type() }); + const auto *uk = get_implementation_logits(SoftmaxSelectorData{ src->info()->data_type(), CPUInfo::get() }); uk->ukernel(src, max, tmp_for_thread, dst, _beta, IS_LOG, window); } |