aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp25
1 files changed, 17 insertions, 8 deletions
diff --git a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp
index 9f98b67582..8632bdf623 100644
--- a/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp
+++ b/src/core/CL/kernels/CLMeanStdDevNormalizationKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,6 +29,9 @@
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
+#include "arm_compute/core/utils/StringUtils.h"
+
#include "src/core/CL/CLValidate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
@@ -47,7 +50,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, f
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
// Checks performed when output is configured
- if((output != nullptr) && (output->total_size() != 0))
+ if ((output != nullptr) && (output->total_size() != 0))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
@@ -59,6 +62,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, f
CLMeanStdDevNormalizationKernel::CLMeanStdDevNormalizationKernel()
: _input(nullptr), _output(nullptr), _run_in_place(false)
{
+ _type = CLKernelType::ELEMENTWISE;
}
void CLMeanStdDevNormalizationKernel::configure(ICLTensor *input, ICLTensor *output, float epsilon)
@@ -66,15 +70,19 @@ void CLMeanStdDevNormalizationKernel::configure(ICLTensor *input, ICLTensor *out
configure(CLKernelLibrary::get().get_compile_context(), input, output, epsilon);
}
-void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_context, ICLTensor *input, ICLTensor *output, float epsilon)
+void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_context,
+ ICLTensor *input,
+ ICLTensor *output,
+ float epsilon)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input);
_run_in_place = (output == nullptr) || (output == input);
- ARM_COMPUTE_ERROR_THROW_ON(CLMeanStdDevNormalizationKernel::validate(input->info(), (output != nullptr) ? output->info() : nullptr, epsilon));
+ ARM_COMPUTE_ERROR_THROW_ON(CLMeanStdDevNormalizationKernel::validate(
+ input->info(), (output != nullptr) ? output->info() : nullptr, epsilon));
- if(output != nullptr)
+ if (output != nullptr)
{
auto_init_if_empty(*output->info(), *input->info());
}
@@ -82,7 +90,8 @@ void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_
_input = input;
_output = output;
- const unsigned int num_elems_processed_per_iteration = adjust_vec_size(16 / input->info()->element_size(), input->info()->dimension(0));
+ const unsigned int num_elems_processed_per_iteration =
+ adjust_vec_size(16 / input->info()->element_size(), input->info()->dimension(0));
// Set build options
CLBuildOptions build_opts;
@@ -90,6 +99,7 @@ void CLMeanStdDevNormalizationKernel::configure(const CLCompileContext &compile_
build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon));
build_opts.add_option("-DWIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
+ build_opts.add_option_if(input->info()->data_type() == DataType::F16, "-DMEANSTDNORM_HALF");
build_opts.add_option_if(_run_in_place, "-DIN_PLACE");
// Create kernel
@@ -130,7 +140,6 @@ void CLMeanStdDevNormalizationKernel::run(const Window &window, cl::CommandQueue
add_2D_tensor_argument_if((!_run_in_place), idx, _output, slice);
enqueue(queue, *this, slice, lws_hint());
- }
- while(window.slide_window_slice_2D(slice));
+ } while (window.slide_window_slice_2D(slice));
}
} // namespace arm_compute