about | summary | refs | log | tree | commit | diff
path: root/src/cpu/kernels/softmax/generic/sve/impl.cpp
diff options
context:
space:
mode:
author: Michalis Spyrou <michalis.spyrou@arm.com> 2022-06-28 19:46:42 +0100
committer: Michalis Spyrou <michalis.spyrou@arm.com> 2022-07-04 16:00:54 +0000
commit: e417ff1d9fde119a238582a3b1feb914edd95c38 (patch)
tree: 7fe8c97e277931bd7d40ce1c98d13851daba0939 /src/cpu/kernels/softmax/generic/sve/impl.cpp
parent: 84a0941cf5bdaffc6127d4cae2e949e6e9109e4a (diff)
download: ComputeLibrary-e417ff1d9fde119a238582a3b1feb914edd95c38.tar.gz
Fix build errors on armv8.6 SVE2 with NDK 23 and 24
Extensive use of templates resulted in a compiler crash on NDK 23 and 24. This rework solves the issue and also reduces the library size by 101Kb.

Resolves: COMPMID-5384
Change-Id: I9c5c68c5e36f236b0891e44d25478743417fb16d
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7871
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/kernels/softmax/generic/sve/impl.cpp')
-rw-r--r-- src/cpu/kernels/softmax/generic/sve/impl.cpp | 21
1 files changed, 11 insertions, 10 deletions
diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp
index f1442224e8..2340a31cbd 100644
--- a/src/cpu/kernels/softmax/generic/sve/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp
@@ -94,8 +94,9 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co
/* Compute exponentials and sum */
{
/* Get max value */
- const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr());
- const auto vec_max = wrapper::svdup_n(max_val);
+ const auto max_val = *reinterpret_cast<const ScalarType *>(max_it.ptr());
+ const auto vec_max = wrapper::svdup_n(max_val);
+ const auto vec_beta = wrapper::svdup_n(static_cast<ScalarType>(beta));
/* Init sum to zero */
auto vec_sum = wrapper::svdup_n(static_cast<ScalarType>(0));
@@ -106,19 +107,19 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co
do
{
auto vec_elements = svld1(pg, in_ptr + x);
- vec_elements = svsub_z(pg, vec_elements, vec_max);
- if(is_log)
- {
- vec_elements = svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta)));
- vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements));
- }
- else
+ vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta);
+ if(!is_log)
{
- vec_elements = wrapper::svexp_z(pg, svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast<ScalarType>(beta))));
+ vec_elements = wrapper::svexp_z(pg, vec_elements);
vec_sum = svadd_m(pg, vec_sum, vec_elements);
}
svst1(pg, tmp_ptr + x, vec_elements);
+ if(is_log)
+ {
+ vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements));
+ }
+
x += wrapper::svcnt<ScalarType>();
pg = wrapper::svwhilelt<ScalarType>(x, input_width);
}