diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2017-11-27 15:50:10 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:41:36 +0000 |
commit | 7062382af52ee9b79576a0a962b5f5b9088b3bae (patch) | |
tree | 6b09e9aa31c1b85600c4f8f00fddae71a3202678 /src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp | |
parent | c000fb86c1462d9297d57ed4a5c92c210a869834 (diff) | |
download | ComputeLibrary-7062382af52ee9b79576a0a962b5f5b9088b3bae.tar.gz |
COMPMID-617 Add validation window to BatchNormalization, PixelwiseMultiplication, ArithmeticOps
Change-Id: I306bd23fcd9b7eb7a248dc762aae899b28300b90
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110763
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Ioan-Cristian Szabo <ioan-cristian.szabo@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp | 80 |
1 files changed, 54 insertions, 26 deletions
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp index f17091166c..38e367dfb7 100644 --- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp @@ -37,6 +37,55 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input, const ITensorInfo *output, + const ITensorInfo *mean, const ITensorInfo *var, + const ITensorInfo *beta, const ITensorInfo *gamma, + float epsilon) +{ + ARM_COMPUTE_UNUSED(epsilon); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); + + if(output != nullptr && output->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) +{ + const unsigned int num_elems_processed_per_iteration = 16 / input->element_size(); + + // Configure kernel window + Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration); + + bool window_changed; + if(output != nullptr) + { + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + window_changed = update_window_and_padding(win, input_access, output_access); + output_access.set_valid_region(win, input->valid_region()); + } + else + { + window_changed = update_window_and_padding(win, input_access); + } + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel() : _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _beta(nullptr), _gamma(nullptr), _epsilon(0) { @@ -85,19 +134,9 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out _kernel.setArg<cl_float>(idx++, _epsilon); // Configure kernel window - Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration); - if(output != nullptr) - { - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - update_window_and_padding(win, input_access, output_access); - output_access.set_valid_region(win, input->info()->valid_region()); - } - else - { - update_window_and_padding(win, input_access); - } - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input->info(), (output == nullptr) ? nullptr : output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, @@ -105,19 +144,8 @@ Error CLBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon) { - ARM_COMPUTE_UNUSED(epsilon); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); - - if(output != nullptr && output->total_size() != 0) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), (output == nullptr) ? nullptr : output->clone().get()).first); return Error{}; } |