diff options
Diffstat (limited to 'src/core/CL')
-rw-r--r-- | src/core/CL/cl_kernels/common/mean_stddev_normalization.cl | 12 | ||||
-rw-r--r-- | src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp | 3 |
2 files changed, 13 insertions, 2 deletions
diff --git a/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl b/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl index 05727a6aa6..22abf64874 100644 --- a/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl +++ b/src/core/CL/cl_kernels/common/mean_stddev_normalization.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -62,7 +62,11 @@ __kernel void mean_stddev_normalization( VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) sum = 0.f; +#ifdef MEANSTDNORM_HALF + VEC_DATA_TYPE(float, VEC_SIZE) +#else /* MEANSTDNORM_HALF */ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE) +#endif /* MEANSTDNORM_HALF */ sum_sq = 0.f; // Calculate partial sum int i = 0; @@ -73,7 +77,13 @@ __kernel void mean_stddev_normalization( data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)offset(&in, i, 0)); sum += data; +#ifdef MEANSTDNORM_HALF + VEC_DATA_TYPE(float, VEC_SIZE) + dsq = CONVERT(data * data, VEC_DATA_TYPE(float, VEC_SIZE)); + sum_sq += dsq; +#else /* MEANSTDNORM_HALF */ sum_sq += data * data; +#endif /* MEANSTDNORM_HALF */ } // Perform reduction sum = SUM_REDUCE(sum, VEC_SIZE); diff --git a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp index da9e367590..b94593943c 100644 --- a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp +++ b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -91,6 +91,7 @@ void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_ build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration)); build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon)); build_opts.add_option("-DWIDTH=" + support::cpp11::to_string(input->info()->dimension(0))); + build_opts.add_option_if(input->info()->data_type() == DataType::F16, "-DMEANSTDNORM_HALF"); build_opts.add_option_if(_run_in_place, "-DIN_PLACE"); // Create kernel |