about | summary | refs | log | tree | commit | diff
path: root/src/cpu/kernels/CpuSoftmaxKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.cpp')
-rw-r--r-- src/cpu/kernels/CpuSoftmaxKernel.cpp 12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index a088fb6660..5cf81f815c 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -50,15 +50,17 @@ namespace
{
/* Softmax */
static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = {
-#ifdef ARM_COMPUTE_ENABLE_SME2
{"sme2_fp32_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data)
- { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2); },
- REGISTER_FP32_NEON(sme2_fp32_softmax)},
-#endif // ARM_COMPUTE_ENABLE_SME2
+ { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP32_SME2(sme2_fp32_softmax)},
{"neon_fp32_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); },
REGISTER_FP32_NEON(neon_fp32_softmax<false>)},
+ {"sme2_fp16_softmax",
+ [](const SoftmaxKernelDataTypeISASelectorData &data)
+ { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); },
+ REGISTER_FP16_SME2(sme2_fp16_softmax)},
{"neon_fp16_softmax",
[](const SoftmaxKernelDataTypeISASelectorData &data)
{ return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; },
@@ -156,7 +158,7 @@ void CpuSoftmaxKernel::configure(
}
const auto *uk = CpuSoftmaxKernel::get_implementation(
- SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log});
+ SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log, axis});
ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel");