about | summary | refs | log | tree | commit | diff
path: root/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp')
-rw-r--r-- src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp 33
1 files changed, 14 insertions, 19 deletions
diff --git a/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
index 0f208573a1..5c2a3d993c 100644
--- a/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,12 +38,9 @@ namespace arm_compute
{
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
{
- ARM_COMPUTE_UNUSED(gamma);
- ARM_COMPUTE_UNUSED(beta);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(epsilon == 0.f, "Epsilon must be different than 0");
-
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.epsilon == 0.f, "Epsilon must be different than 0");
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(input, DataType::F16, DataType::F32);
if(output != nullptr && output->total_size() != 0)
@@ -74,33 +71,31 @@ std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITe
} // namespace
CLInstanceNormalizationLayerKernel::CLInstanceNormalizationLayerKernel()
- : _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12), _run_in_place(false)
+ : _input(nullptr), _output(nullptr), _run_in_place(false)
{
}
-void CLInstanceNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, float gamma, float beta, float epsilon)
+void CLInstanceNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, const InstanceNormalizationLayerKernelInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input);
- _input = input;
- _output = output == nullptr ? input : output;
- _gamma = gamma;
- _beta = beta;
- _epsilon = epsilon;
+ _input = input;
+ _output = output == nullptr ? input : output;
_run_in_place = (output == nullptr) || (output == input);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), gamma, beta, epsilon));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), info));
const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
CLBuildOptions build_opts;
build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ build_opts.add_option("-DINTERNAL_DATA_TYPE=" + (info.use_mixed_precision ? "float" : get_cl_type_from_data_type(input->info()->data_type())));
build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
build_opts.add_option("-DDIM_X=" + support::cpp11::to_string(input->info()->dimension(0)));
build_opts.add_option("-DDIM_Y=" + support::cpp11::to_string(input->info()->dimension(1)));
build_opts.add_option("-DDIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
- build_opts.add_option("-DGAMMA=" + float_to_string_with_full_precision(gamma));
- build_opts.add_option("-DBETA=" + float_to_string_with_full_precision(beta));
- build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon));
+ build_opts.add_option("-DGAMMA=" + float_to_string_with_full_precision(info.gamma));
+ build_opts.add_option("-DBETA=" + float_to_string_with_full_precision(info.beta));
+ build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(info.epsilon));
build_opts.add_option_if(_run_in_place, "-DIN_PLACE");
build_opts.add_option_if(_input->info()->data_layout() == DataLayout::NHWC, "-DNHWC");
@@ -113,9 +108,9 @@ void CLInstanceNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *
ICLKernel::configure_internal(std::get<1>(win_config));
}
-Status CLInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status CLInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, gamma, beta, epsilon));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info));
ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), (output == nullptr ? input->clone().get() : output->clone().get()))));
return Status{};
}