diff options
Diffstat (limited to 'src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp | 53 |
1 files changed, 38 insertions, 15 deletions
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp index 95c8250ee7..62f21eed96 100644 --- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp @@ -46,9 +46,22 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, { ARM_COMPUTE_UNUSED(epsilon); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var); + if(beta != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, beta); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, beta); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, beta); + } + if(gamma != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, gamma); + } + ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); if(act_info.enabled()) { @@ -108,7 +121,7 @@ CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel() void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, const ICLTensor *mean, const ICLTensor *var, const ICLTensor *beta, const ICLTensor *gamma, float epsilon, ActivationLayerInfo act_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var, beta, gamma); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, mean, var); _input = input; _output = output; @@ -120,15 +133,9 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out _run_in_place = (output == nullptr) || (output == input); - if(output != nullptr) - { - ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info()); - // Output tensor auto initialization if not yet initialized - auto_init_if_empty(*output->info(), *input->info()->clone()); - } - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, - mean->info(), var->info(), beta->info(), gamma->info(), epsilon, act_info)); + mean->info(), var->info(), (beta != nullptr) ? beta->info() : nullptr, + (gamma != nullptr) ? gamma->info() : nullptr, epsilon, act_info)); const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size(); @@ -141,13 +148,23 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out build_opts.add_option_if(act_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(act_info.b())); build_opts.add_option_if(_run_in_place, "-DIN_PLACE"); build_opts.add_option_if(is_data_type_fixed_point(input->info()->data_type()), "-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position())); + build_opts.add_option_if(beta == nullptr, "-DUSE_DEFAULT_BETA"); + build_opts.add_option_if(gamma == nullptr, "-DUSE_DEFAULT_GAMMA"); // Create kernel _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("batchnormalization_layer", build_opts.options())); // Set kernel static arguments unsigned int include_output = (!_run_in_place) ? 1 : 0; - unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters + unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor() + 2 * num_arguments_per_1D_tensor(); // Skip the input and output parameters + if(_beta != nullptr) + { + idx += num_arguments_per_1D_tensor(); // Skip beta parameter + } + if(_gamma != nullptr) + { + idx += num_arguments_per_1D_tensor(); // Skip gamma parameter + } _kernel.setArg<cl_float>(idx++, _epsilon); // Configure kernel window @@ -191,8 +208,14 @@ void CLBatchNormalizationLayerKernel::run(const Window &window, cl::CommandQueue unsigned int idx = (1 + include_output) * num_arguments_per_3D_tensor(); add_1D_tensor_argument(idx, _mean, vector_slice); add_1D_tensor_argument(idx, _var, vector_slice); - add_1D_tensor_argument(idx, _beta, vector_slice); - add_1D_tensor_argument(idx, _gamma, vector_slice); + if(_beta != nullptr) + { + add_1D_tensor_argument(idx, _beta, vector_slice); + } + if(_gamma != nullptr) + { + add_1D_tensor_argument(idx, _gamma, vector_slice); + } do { |