From 7062382af52ee9b79576a0a962b5f5b9088b3bae Mon Sep 17 00:00:00 2001
From: Giorgio Arena
Date: Mon, 27 Nov 2017 15:50:10 +0000
Subject: COMPMID-617 Add validation window to BatchNormalization, PixelwiseMultiplication, ArithmeticOps

Change-Id: I306bd23fcd9b7eb7a248dc762aae899b28300b90
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110763
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com
Reviewed-by: Georgios Pinitas
Reviewed-by: Ioan-Cristian Szabo
Reviewed-by: Anthony Barbier
---
 .../CL/kernels/CLPixelWiseMultiplicationKernel.cpp | 108 ++++++++++++---------
 1 file changed, 64 insertions(+), 44 deletions(-)

(limited to 'src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp')

Notes (review):
- The argument checks that used to live in validate() move into a file-local
  validate_arguments(); the window/padding setup that used to live at the end
  of configure() moves into a file-local validate_and_configure_window().
- validate_and_configure_window() returns an (Error, Window) pair; the error is
  "Insufficient Padding!" when update_window_and_padding() returns true.
- validate() now runs validate_arguments() and then the window step on clones
  of the three tensor infos; configure() throws on the window step's error.
- The helper's return type is spelled std::pair<Error, Window>, the type
  produced by std::make_pair(err, win).
- NOTE(review): validate_arguments() tolerates output == nullptr, but
  validate() calls output->clone() unconditionally afterwards - confirm that
  callers never pass a null output.

diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index 5e35c8c1ff..a466fa41b4 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -40,6 +40,65 @@
 
 using namespace arm_compute;
 
+namespace
+{
+Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
+                         ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
+{
+    ARM_COMPUTE_UNUSED(overflow_policy);
+    ARM_COMPUTE_UNUSED(rounding_policy);
+
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
+
+    if(is_data_type_fixed_point(input1->data_type()))
+    {
+        // All data types must be all QS8 or all QS16
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1.");
+    }
+
+    // Validate in case of configured output
+    if((output != nullptr) && (output->total_size() != 0))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
+                                        "Output can only be U8 if both inputs are U8");
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output);
+        if(is_data_type_fixed_point(input1->data_type()))
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
+        }
+    }
+
+    return Error{};
+}
+
+std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+{
+    constexpr unsigned int num_elems_processed_per_iteration = 16;
+
+    Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
+
+    AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration);
+    AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration);
+    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+    bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access);
+
+    ValidRegion valid_region = intersect_valid_regions(input1->valid_region(),
+                                                       input2->valid_region());
+    output_access.set_valid_region(win, valid_region);
+
+    Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{};
+    return std::make_pair(err, win);
+}
+} // namespace
+
 CLPixelWiseMultiplicationKernel::CLPixelWiseMultiplicationKernel()
     : _input1(nullptr), _input2(nullptr), _output(nullptr)
 {
@@ -149,55 +208,16 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I
     }
 
     // Configure kernel window
-    constexpr unsigned int num_elems_processed_per_iteration = 16;
-
-    Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
-
-    AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal input2_access(input2->info(), 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
-
-    update_window_and_padding(win, input1_access, input2_access, output_access);
-
-    ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
-                                                       input2->info()->valid_region());
-    output_access.set_valid_region(win, valid_region);
-
-    ICLKernel::configure(win);
+    auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    ICLKernel::configure(win_config.second);
 }
 
 Error CLPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
                                                 ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
 {
-    ARM_COMPUTE_UNUSED(overflow_policy);
-    ARM_COMPUTE_UNUSED(rounding_policy);
-
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
-
-    if(is_data_type_fixed_point(input1->data_type()))
-    {
-        // All data types must be all QS8 or all QS16
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1.");
-    }
-
-    // Validate in case of configured output
-    if((output != nullptr) && (output->total_size() != 0))
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
-                                        "Output can only be U8 if both inputs are U8");
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output);
-        if(is_data_type_fixed_point(input1->data_type()))
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
-        }
-    }
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
 
     return Error{};
 }
-- 
cgit v1.2.1