aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cpu/kernels/softmax/generic/sve2/impl.cpp20
1 files changed, 10 insertions, 10 deletions
diff --git a/src/cpu/kernels/softmax/generic/sve2/impl.cpp b/src/cpu/kernels/softmax/generic/sve2/impl.cpp
index 9cdfe61446..8f677c62d4 100644
--- a/src/cpu/kernels/softmax/generic/sve2/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/sve2/impl.cpp
@@ -80,13 +80,13 @@ void sve2_softmax_logits_1d_quantized(const ITensor *in, const ITensor *max, voi
svbool_t pg_3 = svunpkhi(svunpkhi(pg));
do
{
- auto vec_elements = svld1(pg, in_ptr + x);
- vec_elements = svsub_z(pg, vec_max, vec_elements);
+ const auto vec_elements = svld1(pg, in_ptr + x);
+ const auto vec_elements_sub = svreinterpret_u8(svsub_z(pg, vec_max, vec_elements));
- auto vec_elements_flt_0 = svcvt_f32_z(pg_0, svunpklo(svunpklo(vec_elements)));
- auto vec_elements_flt_1 = svcvt_f32_z(pg_1, svunpkhi(svunpklo(vec_elements)));
- auto vec_elements_flt_2 = svcvt_f32_z(pg_2, svunpklo(svunpkhi(vec_elements)));
- auto vec_elements_flt_3 = svcvt_f32_z(pg_3, svunpkhi(svunpkhi(vec_elements)));
+ auto vec_elements_flt_0 = svcvt_f32_z(pg_0, svunpklo(svunpklo(vec_elements_sub)));
+ auto vec_elements_flt_1 = svcvt_f32_z(pg_1, svunpkhi(svunpklo(vec_elements_sub)));
+ auto vec_elements_flt_2 = svcvt_f32_z(pg_2, svunpklo(svunpkhi(vec_elements_sub)));
+ auto vec_elements_flt_3 = svcvt_f32_z(pg_3, svunpkhi(svunpkhi(vec_elements_sub)));
if(is_log)
{
@@ -180,10 +180,10 @@ void sve2_softmax_logits_1d_quantized(const ITensor *in, const ITensor *max, voi
if(is_qasymm8_signed)
{
const auto offset_vec = svdup_n_f32(128.f);
- res_0 = svsub_z(pg_0, vec_in_0, offset_vec);
- res_1 = svsub_z(pg_1, vec_in_1, offset_vec);
- res_2 = svsub_z(pg_2, vec_in_2, offset_vec);
- res_3 = svsub_z(pg_3, vec_in_3, offset_vec);
+ res_0 = svsub_z(pg_0, res_0, offset_vec);
+ res_1 = svsub_z(pg_1, res_1, offset_vec);
+ res_2 = svsub_z(pg_2, res_2, offset_vec);
+ res_3 = svsub_z(pg_3, res_3, offset_vec);
}
}