diff options
Diffstat (limited to 'src/cpu/kernels/softmax/generic/sve/impl.cpp')
-rw-r--r-- | src/cpu/kernels/softmax/generic/sve/impl.cpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp index f1442224e8..2340a31cbd 100644 --- a/src/cpu/kernels/softmax/generic/sve/impl.cpp +++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp @@ -94,8 +94,9 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co /* Compute exponentials and sum */ { /* Get max value */ - const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr()); - const auto vec_max = wrapper::svdup_n(max_val); + const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr()); + const auto vec_max = wrapper::svdup_n(max_val); + const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta)); /* Init sum to zero */ auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0)); @@ -106,19 +107,19 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co do { auto vec_elements = svld1(pg, in_ptr + x); - vec_elements = svsub_z(pg, vec_elements, vec_max); - if(is_log) - { - vec_elements = svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta))); - vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements)); - } - else + vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta); + if(!is_log) { - vec_elements = wrapper::svexp_z(pg, svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta)))); + vec_elements = wrapper::svexp_z(pg, vec_elements); vec_sum = svadd_m(pg, vec_sum, vec_elements); } svst1(pg, tmp_ptr + x, vec_elements); + if(is_log) + { + vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements)); + } + x += wrapper::svcnt<ScalarType>(); pg = wrapper::svwhilelt<ScalarType>(x, input_width); } |