diff options
Diffstat (limited to 'src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp index 9f98b67582..8632bdf623 100644 --- a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp +++ b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,6 +29,9 @@ #include "arm_compute/core/CL/OpenCL.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/utils/helpers/AdjustVecSize.h" +#include "arm_compute/core/utils/StringUtils.h" + #include "src/core/CL/CLValidate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -47,7 +50,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, f ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); // Checks performed when output is configured - if((output != nullptr) && (output->total_size() != 0)) + if ((output != nullptr) && (output->total_size() != 0)) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); @@ -59,6 +62,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, f CLMeanStdDevNormalizationKernel::CLMeanStdDevNormalizationKernel() : _input(nullptr), _output(nullptr), _run_in_place(false) { + _type = CLKernelType::ELEMENTWISE; } void CLMeanStdDevNormalizationKernel::configure(ICLTensor *input, ICLTensor *output, float epsilon) @@ -66,15 +70,19 @@ void CLMeanStdDevNormalizationKernel::configure(ICLTensor *input, ICLTensor *out configure(CLKernelLibrary::get().get_compile_context(), input, output, epsilon); } -void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_context, ICLTensor *input, ICLTensor *output, float epsilon) +void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_context, + ICLTensor *input, + ICLTensor *output, + float epsilon) { ARM_COMPUTE_ERROR_ON_NULLPTR(input); _run_in_place = (output == nullptr) || (output == input); - ARM_COMPUTE_ERROR_THROW_ON(CLMeanStdDevNormalizationKernel::validate(input->info(), (output != nullptr) ? output->info() : nullptr, epsilon)); + ARM_COMPUTE_ERROR_THROW_ON(CLMeanStdDevNormalizationKernel::validate( + input->info(), (output != nullptr) ? output->info() : nullptr, epsilon)); - if(output != nullptr) + if (output != nullptr) { auto_init_if_empty(*output->info(), *input->info()); } @@ -82,7 +90,8 @@ void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_ _input = input; _output = output; - const unsigned int num_elems_processed_per_iteration = adjust_vec_size(16 / input->info()->element_size(), input->info()->dimension(0)); + const unsigned int num_elems_processed_per_iteration = + adjust_vec_size(16 / input->info()->element_size(), input->info()->dimension(0)); // Set build options CLBuildOptions build_opts; @@ -90,6 +99,7 @@ void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_ build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration)); build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon)); build_opts.add_option("-DWIDTH=" + support::cpp11::to_string(input->info()->dimension(0))); + build_opts.add_option_if(input->info()->data_type() == DataType::F16, "-DMEANSTDNORM_HALF"); build_opts.add_option_if(_run_in_place, "-DIN_PLACE"); // Create kernel @@ -130,7 +140,6 @@ void CLMeanStdDevNormalizationKernel::run(const Window &window, cl::CommandQueue add_2D_tensor_argument_if((!_run_in_place), idx, _output, slice); enqueue(queue, *this, slice, lws_hint()); - } - while(window.slide_window_slice_2D(slice)); + } while (window.slide_window_slice_2D(slice)); } } // namespace arm_compute |