Diffstat (limited to 'src/core/cpu/kernels/CpuSoftmaxKernel.cpp')
-rw-r--r--  src/core/cpu/kernels/CpuSoftmaxKernel.cpp | 49
1 file changed, 24 insertions(+), 25 deletions(-)
diff --git a/src/core/cpu/kernels/CpuSoftmaxKernel.cpp b/src/core/cpu/kernels/CpuSoftmaxKernel.cpp
index 8ea186b16a..1e00e12050 100644
--- a/src/core/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/core/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -47,7 +47,8 @@ namespace
{
struct SoftmaxSelectorData
{
-    DataType dt;
+    DataType       dt;
+    const CPUInfo &ci;
};
using SoftmaxSelectorPtr = std::add_pointer<bool(const SoftmaxSelectorData &data)>::type;
using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type;
@@ -69,20 +70,20 @@ struct SoftmaxLogits1DMaxKernel
static const SoftmaxLogits1DKernel available_logits_1d_kernels[] =
{
-#if defined(ENABLE_SVE)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
{
"sve_softmax_logits_1d_float",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32) && data.ci.has_sve(); },
REGISTER_FP32_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float>)
},
{
"sve_softmax_logits_1d_float",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16) && data.ci.has_sve(); },
REGISTER_FP16_SVE(arm_compute::cpu::sve_softmax_logits_1d_float<float16_t>)
},
-#endif /* defined(ENABLE_SVE) */
+#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
-#if defined(ENABLE_NEON)
+#if defined(ARM_COMPUTE_ENABLE_NEON)
{
"neon_softmax_logits_1d_float",
[](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
@@ -95,20 +96,20 @@ static const SoftmaxLogits1DKernel available_logits_1d_kernels[] =
REGISTER_FP16_NEON(arm_compute::cpu::neon_softmax_logits_1d_float<float16_t>)
},
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
-#endif /* !defined(ENABLE_NEON) */
+#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
-#if defined(__ARM_FEATURE_SVE2)
+#if defined(ARM_COMPUTE_ENABLE_SVE2)
{
"sve_softmax_logits_1d_quantized",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8) && data.ci.has_sve2(); },
REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_t>)
},
{
"sve_softmax_logits_1d_quantized",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve2(); },
REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_softmax_logits_1d_quantized<qasymm8_signed_t>)
},
-#else /* !defined(__ARM_FEATURE_SVE2) */
+#endif /* defined(ARM_COMPUTE_ENABLE_SVE2) */
{
"neon_softmax_logits_1d_quantized",
[](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
@@ -119,35 +120,33 @@ static const SoftmaxLogits1DKernel available_logits_1d_kernels[] =
[](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_softmax_logits_1d_quantized<qasymm8_signed_t>)
},
-#endif /* defined(__ARM_FEATURE_SVE2) */
-
};
static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] =
{
-#if defined(ENABLE_SVE)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
{
"sve_logits_1d_max",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32) && data.ci.has_sve(); },
REGISTER_FP32_SVE(arm_compute::cpu::sve_logits_1d_max<float>)
},
{
"sve_logits_1d_max",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::F16) && data.ci.has_sve(); },
REGISTER_FP16_SVE(arm_compute::cpu::sve_logits_1d_max<float16_t>)
},
{
"sve_logits_1d_max",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8) && data.ci.has_sve(); },
REGISTER_QASYMM8_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_t>)
},
{
"sve_logits_1d_max",
- [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
+ [](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve(); },
REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::sve_logits_1d_max<qasymm8_signed_t>)
},
-#endif /* defined(ENABLE_SVE) */
-#if defined(ENABLE_NEON)
+#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
+#if defined(ARM_COMPUTE_ENABLE_NEON)
{
"neon_logits_1d_max",
[](const SoftmaxSelectorData & data) { return (data.dt == DataType::F32); },
@@ -170,14 +169,14 @@ static const SoftmaxLogits1DMaxKernel available_logits_1d_max_kernels[] =
[](const SoftmaxSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_logits_1d_max<qasymm8_signed_t>)
},
-#endif /* defined(ENABLE_NEON) */
+#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
};
const SoftmaxLogits1DKernel *get_implementation_logits(const SoftmaxSelectorData &data)
{
for(const auto &uk : available_logits_1d_kernels)
{
- if(uk.is_selected({ data.dt }))
+ if(uk.is_selected({ data.dt, CPUInfo::get() }))
{
return &uk;
}
@@ -189,7 +188,7 @@ const SoftmaxLogits1DMaxKernel *get_implementation_logits_max(const SoftmaxSelec
{
for(const auto &uk : available_logits_1d_max_kernels)
{
- if(uk.is_selected({ data.dt }))
+ if(uk.is_selected({ data.dt, CPUInfo::get() }))
{
return &uk;
}
@@ -253,7 +252,7 @@ void CpuLogits1DMaxKernel::run_op(ITensorPack &tensors, const Window &window, co
const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
auto dst = tensors.get_tensor(TensorType::ACL_DST);
- const auto *uk = get_implementation_logits_max(SoftmaxSelectorData{ src->info()->data_type() });
+ const auto *uk = get_implementation_logits_max(SoftmaxSelectorData{ src->info()->data_type(), CPUInfo::get() });
uk->ukernel(src, dst, window);
}
@@ -364,7 +363,7 @@ void CpuLogits1DSoftmaxKernel<IS_LOG>::run_op(ITensorPack &tensors, const Window
void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
- const auto *uk = get_implementation_logits(SoftmaxSelectorData{ src->info()->data_type() });
+ const auto *uk = get_implementation_logits(SoftmaxSelectorData{ src->info()->data_type(), CPUInfo::get() });
uk->ukernel(src, max, tmp_for_thread, dst, _beta, IS_LOG, window);
}
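
The change above swaps the compile-time ENABLE_SVE / __ARM_FEATURE_SVE2 guards for the ARM_COMPUTE_ENABLE_* build flags plus a runtime check: each SVE/SVE2 selector lambda now also tests the CPUInfo passed in (has_sve() / has_sve2()), which get_implementation_logits() and get_implementation_logits_max() supply via CPUInfo::get(). Below is a minimal standalone sketch of that selector-table pattern, assuming a stand-in CpuFeatures struct and dummy kernels; these names are illustrative only, not the Compute Library API.

```cpp
// Sketch of runtime micro-kernel selection: a table of (name, selector, kernel)
// entries, where the selector checks both the data type and the CPU features
// detected at runtime. CpuFeatures and the kernel stubs are hypothetical.
#include <cstdio>

enum class DataType { F32, F16, QASYMM8 };

struct CpuFeatures            // stand-in for the library's CPUInfo
{
    bool sve;
    bool sve2;
};

struct SelectorData
{
    DataType           dt;
    const CpuFeatures &ci;
};

using SelectorFn = bool (*)(const SelectorData &);
using KernelFn   = void (*)(const float *, float *, int);

void sve_softmax_f32(const float *, float *, int)  { std::puts("sve kernel"); }
void neon_softmax_f32(const float *, float *, int) { std::puts("neon kernel"); }

struct KernelEntry
{
    const char *name;
    SelectorFn  is_selected;
    KernelFn    ukernel;
};

// Order matters: the SVE entry comes first so it wins when the CPU has SVE;
// otherwise its selector returns false and lookup falls through to NEON.
static const KernelEntry kernels[] =
{
    {
        "sve_softmax_f32",
        [](const SelectorData &d) { return d.dt == DataType::F32 && d.ci.sve; },
        sve_softmax_f32
    },
    {
        "neon_softmax_f32",
        [](const SelectorData &d) { return d.dt == DataType::F32; },
        neon_softmax_f32
    },
};

const KernelEntry *pick_kernel(const SelectorData &data)
{
    for(const auto &k : kernels)
    {
        if(k.is_selected(data))
        {
            return &k;
        }
    }
    return nullptr;
}

int main()
{
    CpuFeatures ci{ /*sve=*/false, /*sve2=*/false }; // pretend runtime detection
    const auto *k = pick_kernel({ DataType::F32, ci });
    if(k != nullptr)
    {
        float in[4]{}, out[4]{};
        k->ukernel(in, out, 4);                      // prints "neon kernel"
    }
}
```

Because the SVE entries precede the NEON ones and their selectors additionally require the feature bit, a single binary can ship both kernel families and quietly fall back to NEON on hardware without SVE, which is exactly what the table ordering in the diff relies on.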