aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMurray Kornelsen <murray.kornelsen@mail.mcgill.ca>2022-07-22 18:04:59 -0400
committerPablo Marquez Tello <pablo.tello@arm.com>2022-09-02 10:44:06 +0000
commit552fe4c67d3cd2994cdbd5662cde79da5caf0c4d (patch)
tree501cc20a0f878e5d7542977180c843b1b7778784 /src
parent1257131193fdb9b6940055a41691320e37a208b5 (diff)
downloadComputeLibrary-552fe4c67d3cd2994cdbd5662cde79da5caf0c4d.tar.gz
F16 Specialization for MeanStdDevNorm
Ran into issues with f16 meanstddevnorm. Essentially, with large enough tensors and/or large values in tensors, output becomes all 0. This is due to the variance computation. In f16, it reaches infinity quite easily, then the division results in 0. This change modifies the OpenCL and NEON implementations to compute the sum of squares and the variance using f32, while other operations remain f16. Update: Found that the square operation also benefits from f32, rather than squaring in f16 and accumulating f32. Signed-off-by: Murray Kornelsen <murray.kornelsen@mail.mcgill.ca> Change-Id: Ide00afd84ec6d26fec4d53b073e295814f08ba46 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7959 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Pablo Marquez Tello <pablo.tello@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src')
-rw-r--r--src/core/CL/cl_kernels/common/mean_stddev_normalization.cl12
-rw-r--r--src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp3
-rw-r--r--src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp74
3 files changed, 86 insertions, 3 deletions
diff --git a/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl b/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
index 05727a6aa6..22abf64874 100644
--- a/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
+++ b/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,7 +62,11 @@ __kernel void mean_stddev_normalization(
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
sum = 0.f;
+#ifdef MEANSTDNORM_HALF
+ VEC_DATA_TYPE(float, VEC_SIZE)
+#else /* MEANSTDNORM_HALF */
VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+#endif /* MEANSTDNORM_HALF */
sum_sq = 0.f;
// Calculate partial sum
int i = 0;
@@ -73,7 +77,13 @@ __kernel void mean_stddev_normalization(
data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)offset(&in, i, 0));
sum += data;
+#ifdef MEANSTDNORM_HALF
+ VEC_DATA_TYPE(float, VEC_SIZE)
+ dsq = CONVERT(data * data, VEC_DATA_TYPE(float, VEC_SIZE));
+ sum_sq += dsq;
+#else /* MEANSTDNORM_HALF */
sum_sq += data * data;
+#endif /* MEANSTDNORM_HALF */
}
// Perform reduction
sum = SUM_REDUCE(sum, VEC_SIZE);
diff --git a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp
index da9e367590..b94593943c 100644
--- a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp
+++ b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -91,6 +91,7 @@ void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_
build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon));
build_opts.add_option("-DWIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
+ build_opts.add_option_if(input->info()->data_type() == DataType::F16, "-DMEANSTDNORM_HALF");
build_opts.add_option_if(_run_in_place, "-DIN_PLACE");
// Create kernel
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp
index be07ea78e4..0d00acdd0c 100644
--- a/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp
+++ b/src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp
@@ -103,7 +103,79 @@ void mean_stddev_normalization(ITensor *input, ITensor *output, float epsilon, c
template void mean_stddev_normalization<float, 4>(ITensor *input, ITensor *output, float epsilon, const Window &window);
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
-template void mean_stddev_normalization<float16_t, 8>(ITensor *input, ITensor *output, float epsilon, const Window &window);
+template <>
+void mean_stddev_normalization<float16_t, 8>(ITensor *input, ITensor *output, float epsilon, const Window &window)
+{
+ // Set build options
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ const int window_step_x = 8;
+ const auto window_start_x = static_cast<int>(window.x().start());
+ const auto window_end_x = static_cast<int>(window.x().end());
+
+ Iterator input_itr(input, win);
+ Iterator output_itr(output, win);
+
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ int x = window_start_x;
+ auto in_ptr = reinterpret_cast<const float16_t *>(input_itr.ptr());
+ auto out_ptr = reinterpret_cast<float16_t *>(output_itr.ptr());
+
+ float16x8_t sum_vec = vdupq_n_f16(static_cast<float16_t>(0.0f));
+ float32x4_t sum_sq_vec = vdupq_n_f32(0.0f);
+
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ float16x8_t data = vld1q_f16(in_ptr + x);
+ sum_vec = vaddq_f16(sum_vec, data);
+ float32x4_t dl = vcvt_f32_f16(vget_low_f16(data));
+ float32x4_t dh = vcvt_f32_f16(vget_high_f16(data));
+ sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dl, dl));
+ sum_sq_vec = vaddq_f32(sum_sq_vec, vmulq_f32(dh, dh));
+ }
+
+ float16x4_t sum_carry_res = vpadd_f16(vget_high_f16(sum_vec), vget_low_f16(sum_vec));
+ sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res);
+ sum_carry_res = vpadd_f16(sum_carry_res, sum_carry_res);
+
+ float32x4_t sum_sq_carry_res = vpaddq_f32(sum_sq_vec, sum_sq_vec);
+ sum_sq_carry_res = vpaddq_f32(sum_sq_carry_res, sum_sq_carry_res);
+
+ float16_t sum = vget_lane_f16(sum_carry_res, 0);
+ float sum_sq = vgetq_lane_f32(sum_sq_carry_res, 0);
+
+ // Compute left-over elements
+ for(; x < window_end_x; ++x)
+ {
+ float16_t data = *(in_ptr + x);
+ sum += data;
+ float fdata = static_cast<float>(data);
+ sum_sq += fdata * fdata;
+ }
+
+ float16_t mean = sum / input->info()->dimension(0);
+ float var = (sum_sq / input->info()->dimension(0)) - (mean * mean);
+ float16_t stddev_inv = static_cast<float16_t>(1.f / sqrt(var + epsilon));
+
+ float16x8_t mean_vec = vdupq_n_f16(mean);
+ float16x8_t stddev_inv_vec = vdupq_n_f16(stddev_inv);
+
+ for(x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ float16x8_t data = vld1q_f16(in_ptr + x);
+ float16x8_t res = vmulq_f16(vsubq_f16(data, mean_vec), stddev_inv_vec);
+ // Store results
+ vst1q_f16(out_ptr + x, res);
+ }
+ for(; x < window_end_x; ++x)
+ {
+ *(out_ptr + x) = (*(in_ptr + x) - mean) * stddev_inv;
+ }
+ },
+ input_itr, output_itr);
+}
#endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
} // namespace cpu