aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPablo Marquez Tello <pablo.tello@arm.com>2024-03-18 16:24:45 +0000
committerPablo Marquez Tello <pablo.tello@arm.com>2024-03-19 17:17:41 +0000
commitc00a82b1467c02e46093b65f29aaad8fbf794dfe (patch)
treea74884cf17f4ac9af7a79bb54e2f8c5130942360
parent3e4b193f783c2d43547123518cadd1b2a9b11055 (diff)
downloadComputeLibrary-c00a82b1467c02e46093b65f29aaad8fbf794dfe.tar.gz
Fix overflow in NEMeanStdDevNormalizationKernel
* Perform final sum in fp32 to avoid overflow * Resolves ARMCL-1128 Change-Id: I89799baf81045697f7bc44017fcb6a440635caff Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11311 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp22
1 files changed, 8 insertions, 14 deletions
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
index 6470f391e2..344b9df0c8 100644
--- a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,26 +66,20 @@ void mean_stddev_normalization<float16_t, 8>(ITensor *input, ITensor *output, fl
sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
}
- float16x4_t sum_carry_res = vpadd_f16(vget_high_f16(sum_vec), vget_low_f16(sum_vec));
- sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res);
- sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res);
-
- float32x4_t sum_sq_carry_res = vpaddq_f32(sum_sq_vec, sum_sq_vec);
- sum_sq_carry_res = vpaddq_f32(sum_sq_carry_res, sum_sq_carry_res);
-
- float16_t sum = vget_lane_f16(sum_carry_res, 0);
- float sum_sq = vgetq_lane_f32(sum_sq_carry_res, 0);
+ float32x4_t sum_carry_res =
+ vpaddq_f32(vcvt_f32_f16(vget_high_f16(sum_vec)), vcvt_f32_f16(vget_low_f16(sum_vec)));
+ float sum = vaddvq_f32(sum_carry_res);
+ float sum_sq = vaddvq_f32(sum_sq_vec);
// Compute left-over elements
for (; x < window_end_x; ++x)
{
- float16_t data = *(in_ptr + x);
- sum += data;
- float fdata = static_cast<float>(data);
+ const float fdata = static_cast<float>(*(in_ptr + x));
+ sum += fdata;
sum_sq += fdata * fdata;
}
- float16_t mean = sum / input->info()->dimension(0);
+ float16_t mean = static_cast<float16_t>(sum / input->info()->dimension(0));
float var = (sum_sq / input->info()->dimension(0)) - (mean * mean);
float16_t stddev_inv = static_cast<float16_t>(1.f / sqrt(var + epsilon));