diff options
author | Gunes Bayir <gunes.bayir@arm.com> | 2024-04-09 23:13:04 +0100 |
---|---|---|
committer | Gunes Bayir <gunes.bayir@arm.com> | 2024-04-11 12:58:45 +0000 |
commit | cfca87b91def4f455630f2094447dc0500b6256c (patch) | |
tree | 9985ca8ad1910d48a84aa9781fe3156e614ff5f4 /src/cpu/kernels/CpuSoftmaxKernel.cpp | |
parent | f1f1f87132690a8061801ef1a4638d637c780df7 (diff) | |
download | ComputeLibrary-cfca87b91def4f455630f2094447dc0500b6256c.tar.gz |
Add SME2 implementation of softmax for FP16
In addition to the softmax kernel, this patch fixes minor issues in the fp32 implementation.
Resolves: COMPMID-6920
Change-Id: Ibbd9f0af5f2a93fba0e92d72ba437279c34149d3
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11402
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuSoftmaxKernel.cpp | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp index a088fb6660..5cf81f815c 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.cpp +++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp @@ -50,15 +50,17 @@ namespace { /* Softmax */ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_kernels = { -#ifdef ARM_COMPUTE_ENABLE_SME2 {"sme2_fp32_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) - { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2); }, - REGISTER_FP32_NEON(sme2_fp32_softmax)}, -#endif // ARM_COMPUTE_ENABLE_SME2 + { return (!data.is_log && data.dt == DataType::F32 && data.isa.sme2 && data.axis == 0); }, + REGISTER_FP32_SME2(sme2_fp32_softmax)}, {"neon_fp32_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F32); }, REGISTER_FP32_NEON(neon_fp32_softmax<false>)}, + {"sme2_fp16_softmax", + [](const SoftmaxKernelDataTypeISASelectorData &data) + { return (!data.is_log && data.dt == DataType::F16 && data.isa.sme2 && data.axis == 0); }, + REGISTER_FP16_SME2(sme2_fp16_softmax)}, {"neon_fp16_softmax", [](const SoftmaxKernelDataTypeISASelectorData &data) { return (!data.is_log && data.dt == DataType::F16) && data.isa.fp16; }, @@ -156,7 +158,7 @@ void CpuSoftmaxKernel::configure( } const auto *uk = CpuSoftmaxKernel::get_implementation( - SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log}); + SoftmaxKernelDataTypeISASelectorData{src->data_type(), CPUInfo::get().get_isa(), is_log, axis}); ARM_COMPUTE_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); std::string kernel_name = is_log ? std::string("CpuLogSoftmaxKernel") : std::string("CpuSoftmaxKernel"); |