aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/softmax/generic/sve/impl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/softmax/generic/sve/impl.cpp')
-rw-r--r--src/cpu/kernels/softmax/generic/sve/impl.cpp21
1 files changed, 11 insertions, 10 deletions
diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp
index f1442224e8..2340a31cbd 100644
--- a/src/cpu/kernels/softmax/generic/sve/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp
@@ -94,8 +94,9 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co
/* Compute exponentials and sum */
{
/* Get max value */
- const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr());
- const auto vec_max = wrapper::svdup_n(max_val);
+ const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr());
+ const auto vec_max = wrapper::svdup_n(max_val);
+ const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta));
/* Init sum to zero */
auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0));
@@ -106,19 +107,19 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co
do
{
auto vec_elements = svld1(pg, in_ptr + x);
- vec_elements = svsub_z(pg, vec_elements, vec_max);
- if(is_log)
- {
- vec_elements = svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta)));
- vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements));
- }
- else
+ vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta);
+ if(!is_log)
{
- vec_elements = wrapper::svexp_z(pg, svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta))));
+ vec_elements = wrapper::svexp_z(pg, vec_elements);
vec_sum = svadd_m(pg, vec_sum, vec_elements);
}
svst1(pg, tmp_ptr + x, vec_elements);
+ if(is_log)
+ {
+ vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements));
+ }
+
x += wrapper::svcnt<ScalarType>();
pg = wrapper::svwhilelt<ScalarType>(x, input_width);
}