From c00a82b1467c02e46093b65f29aaad8fbf794dfe Mon Sep 17 00:00:00 2001 From: Pablo Marquez Tello Date: Mon, 18 Mar 2024 16:24:45 +0000 Subject: Fix overflow in NEMeanStdDevNormalizationKernel * Perform final sum in fp32 to avoid overflow * Resolves ARMCL-1128 Change-Id: I89799baf81045697f7bc44017fcb6a440635caff Signed-off-by: Pablo Marquez Tello Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11311 Reviewed-by: Gunes Bayir Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Benchmark: Arm Jenkins --- .../kernels/meanstddevnorm/generic/neon/fp16.cpp | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp index 6470f391e2..344b9df0c8 100644 --- a/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp +++ b/src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023 Arm Limited. + * Copyright (c) 2022-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -66,26 +66,20 @@ void mean_stddev_normalization(ITensor *input, ITensor *output, fl sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh)); } - float16x4_t sum_carry_res = vpadd_f16(vget_high_f16(sum_vec), vget_low_f16(sum_vec)); - sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res); - sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res); - - float32x4_t sum_sq_carry_res = vpaddq_f32(sum_sq_vec, sum_sq_vec); - sum_sq_carry_res = vpaddq_f32(sum_sq_carry_res, sum_sq_carry_res); - - float16_t sum = vget_lane_f16(sum_carry_res, 0); - float sum_sq = vgetq_lane_f32(sum_sq_carry_res, 0); + float32x4_t sum_carry_res = + vpaddq_f32(vcvt_f32_f16(vget_high_f16(sum_vec)), vcvt_f32_f16(vget_low_f16(sum_vec))); + float sum = vaddvq_f32(sum_carry_res); + float sum_sq = vaddvq_f32(sum_sq_vec); // Compute left-over elements for (; x < window_end_x; ++x) { - float16_t data = *(in_ptr + x); - sum += data; - float fdata = static_cast(data); + const float fdata = static_cast(*(in_ptr + x)); + sum += fdata; sum_sq += fdata * fdata; } - float16_t mean = sum / input->info()->dimension(0); + float16_t mean = static_cast(sum / input->info()->dimension(0)); float var = (sum_sq / input->info()->dimension(0)) - (mean * mean); float16_t stddev_inv = static_cast(1.f / sqrt(var + epsilon)); -- cgit v1.2.1