Diffstat (limited to 'src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp')
-rw-r--r--  src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp  53
1 file changed, 38 insertions(+), 15 deletions(-)
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
index 95c8250ee7..62f21eed96 100644
--- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
@@ -46,9 +46,22 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output,
{
ARM_COMPUTE_UNUSED(epsilon);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var);
+ if(beta != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta);
+ }
+ if(gamma != nullptr)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma);
+ }
+
ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0));
if(act_info.enabled())
{
@@ -108,7 +121,7 @@ CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel()
void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, const ICLTensor *mean, const ICLTensor *var, const ICLTensor *beta, const ICLTensor *gamma,
float epsilon, ActivationLayerInfo act_info)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var, beta, gamma);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var);
_input = input;
_output = output;
@@ -120,15 +133,9 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out
_run_in_place = (output == nullptr) || (output == input);
- if(output != nullptr)
- {
- ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info());
- // Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), *input->info()->clone());
- }
-
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr,
- mean->info(), var->info(), beta->info(), gamma->info(), epsilon, act_info));
+ mean->info(), var->info(), (beta != nullptr) ? beta->info() : nullptr,
+ (gamma != nullptr) ? gamma->info() : nullptr, epsilon, act_info));
const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
@@ -141,13 +148,23 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out
build_opts.add_option_if(act_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(act_info.b()));
build_opts.add_option_if(_run_in_place, "-DIN_PLACE");
build_opts.add_option_if(is_data_type_fixed_point(input->info()->data_type()), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
+ build_opts.add_option_if(beta == nullptr, "-DUSE_DEFAULT_BETA");
+ build_opts.add_option_if(gamma == nullptr, "-DUSE_DEFAULT_GAMMA");
// Create kernel
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("batchnormalization_layer", build_opts.options()));
// Set kernel static arguments
unsigned int include_output = (!_run_in_place) ? 1 : 0;
- unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters
+ unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor() + 2 * num_arguments_per_1D_tensor(); // Skip the input and output parameters
+ if(_beta != nullptr)
+ {
+ idx += num_arguments_per_1D_tensor(); // Skip beta parameter
+ }
+ if(_gamma != nullptr)
+ {
+ idx += num_arguments_per_1D_tensor(); // Skip gamma parameter
+ }
_kernel.setArg<cl_float>(idx++, _epsilon);
// Configure kernel window
@@ -191,8 +208,14 @@ void CLBatchNormalizationLayerKernel::run(const Window &window, cl::CommandQueue
unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor();
add_1D_tensor_argument(idx, _mean, vector_slice);
add_1D_tensor_argument(idx, _var, vector_slice);
- add_1D_tensor_argument(idx, _beta, vector_slice);
- add_1D_tensor_argument(idx, _gamma, vector_slice);
+ if(_beta != nullptr)
+ {
+ add_1D_tensor_argument(idx, _beta, vector_slice);
+ }
+ if(_gamma != nullptr)
+ {
+ add_1D_tensor_argument(idx, _gamma, vector_slice);
+ }
do
{
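The hunks above make beta and gamma optional: validation, the OpenCL build options, the static epsilon argument index and the per-slice 1D tensor arguments are all guarded on the pointers being non-null, and a missing tensor is replaced in the kernel via -DUSE_DEFAULT_BETA / -DUSE_DEFAULT_GAMMA (presumably the usual batch-normalization defaults of beta = 0 and gamma = 1). A minimal caller-side sketch follows; the tensor set-up is assumed to happen elsewhere and the epsilon value is illustrative, not taken from this commit:

    #include "arm_compute/core/CL/kernels/CLBatchNormalizationLayerKernel.h"
    #include "arm_compute/runtime/CL/CLTensor.h"

    using namespace arm_compute;

    // Hypothetical helper: configures the kernel without trainable beta/gamma tensors.
    void configure_without_beta_gamma(CLTensor &input, CLTensor &output, CLTensor &mean, CLTensor &var)
    {
        CLBatchNormalizationLayerKernel kernel;
        // After this change beta and gamma may be nullptr; the kernel program is then
        // compiled with -DUSE_DEFAULT_BETA / -DUSE_DEFAULT_GAMMA and the corresponding
        // 1D tensor arguments are skipped when the kernel runs.
        kernel.configure(&input, &output, &mean, &var,
                         nullptr /* beta */, nullptr /* gamma */,
                         0.001f, ActivationLayerInfo());
    }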