aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
diff options
context:
space:
mode:
authorMurray Kornelsen <murray.kornelsen@mail.mcgill.ca>2022-07-22 18:04:59 -0400
committerPablo Marquez Tello <pablo.tello@arm.com>2022-09-02 10:44:06 +0000
commit552fe4c67d3cd2994cdbd5662cde79da5caf0c4d (patch)
tree501cc20a0f878e5d7542977180c843b1b7778784 /src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
parent1257131193fdb9b6940055a41691320e37a208b5 (diff)
downloadComputeLibrary-552fe4c67d3cd2994cdbd5662cde79da5caf0c4d.tar.gz
F16 Specialization for MeanStdDevNorm
Ran into issues with f16 meanstddevnorm. Essentially, with large enough tensors and/or large values in tensors, output becomes all 0. This is due to the variance computation. In f16, it reaches infinity quite easily, then the division results in 0. This change modifies the OpenCL and NEON implementations to compute the sum of squares and the variance using f32, while other operations remain f16. Update: Found that the square operation also benefits from f32, rather than squaring in f16 and accumulating f32. Signed-off-by: Murray Kornelsen <murray.kornelsen@mail.mcgill.ca> Change-Id: Ide00afd84ec6d26fec4d53b073e295814f08ba46 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7959 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Pablo Marquez Tello <pablo.tello@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/CL/cl_kernels/common/mean_stddev_normalization.cl')
-rw-r--r--src/core/CL/cl_kernels/common/mean_stddev_normalization.cl12
1 files changed, 11 insertions, 1 deletions
diff --git a/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl b/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
index 05727a6aa6..22abf64874 100644
--- a/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
+++ b/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,7 +62,11 @@ __kernel void mean_stddev_normalization(
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
sum = 0.f;
+#ifdef MEANSTDNORM_HALF
+ VEC_DATA_TYPE(float, VEC_SIZE)
+#else /* MEANSTDNORM_HALF */
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+#endif /* MEANSTDNORM_HALF */
sum_sq = 0.f;
// Calculate partial sum
int i = 0;
@@ -73,7 +77,13 @@ __kernel void mean_stddev_normalization(
data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)offset(&in, i, 0));
sum += data;
+#ifdef MEANSTDNORM_HALF
+ VEC_DATA_TYPE(float, VEC_SIZE)
+ dsq = CONVERT(data * data, VEC_DATA_TYPE(float, VEC_SIZE));
+ sum_sq += dsq;
+#else /* MEANSTDNORM_HALF */
sum_sq += data * data;
+#endif /* MEANSTDNORM_HALF */
}
// Perform reduction
sum = SUM_REDUCE(sum, VEC_SIZE);