diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2017-11-27 15:50:10 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:41:36 +0000 |
commit | 7062382af52ee9b79576a0a962b5f5b9088b3bae (patch) | |
tree | 6b09e9aa31c1b85600c4f8f00fddae71a3202678 /src | |
parent | c000fb86c1462d9297d57ed4a5c92c210a869834 (diff) | |
download | ComputeLibrary-7062382af52ee9b79576a0a962b5f5b9088b3bae.tar.gz |
COMPMID-617 Add validation window to BatchNormalization, PixelwiseMultiplication, ArithmeticOps
Change-Id: I306bd23fcd9b7eb7a248dc762aae899b28300b90
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110763
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Ioan-Cristian Szabo <ioan-cristian.szabo@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src')
4 files changed, 217 insertions, 131 deletions
diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp index a7625f4303..831389e3b6 100644 --- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp +++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp @@ -41,6 +41,51 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_UNUSED(policy); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + + // Validate in case of configured output + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + + AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLArithmeticAdditionKernel::CLArithmeticAdditionKernel() : _input1(nullptr), _input2(nullptr), _output(nullptr) { @@ -91,41 +136,15 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_add", build_opts)); // Configure kernel window - constexpr unsigned int num_elems_processed_per_iteration = 16; - - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - - AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal input2_access(input2->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, input1_access, input2_access, output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); - - output_access.set_valid_region(win, valid_region); - - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) { - ARM_COMPUTE_UNUSED(policy); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); - - // Validate in case of configured output - if((output != nullptr) && (output->total_size() != 0)) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); return Error{}; } diff --git a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp index 47d77ad8a9..5603451ca0 100644 --- a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp +++ b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp @@ -38,6 +38,50 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_UNUSED(policy); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + + // Validate in case of configured output + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLArithmeticSubtractionKernel::CLArithmeticSubtractionKernel() : _input1(nullptr), _input2(nullptr), _output(nullptr) { @@ -84,40 +128,15 @@ void CLArithmeticSubtractionKernel::configure(const ICLTensor *input1, const ICL _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_sub", build_opts)); // Configure kernel window - constexpr unsigned int num_elems_processed_per_iteration = 16; - - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal input2_access(input2->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, input1_access, input2_access, output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); - - output_access.set_valid_region(win, valid_region); - - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) { - ARM_COMPUTE_UNUSED(policy); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); - - // Validate in case of configured output - if((output != nullptr) && (output->total_size() != 0)) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); return Error{}; } diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp index f17091166c..38e367dfb7 100644 --- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp @@ -37,6 +37,55 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input, const ITensorInfo *output, + const ITensorInfo *mean, const ITensorInfo *var, + const ITensorInfo *beta, const ITensorInfo *gamma, + float epsilon) +{ + ARM_COMPUTE_UNUSED(epsilon); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); + + if(output != nullptr && output->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) +{ + const unsigned int num_elems_processed_per_iteration = 16 / input->element_size(); + + // Configure kernel window + Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration); + + bool window_changed; + if(output != nullptr) + { + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + window_changed = update_window_and_padding(win, input_access, output_access); + output_access.set_valid_region(win, input->valid_region()); + } + else + { + window_changed = update_window_and_padding(win, input_access); + } + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel() : _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _beta(nullptr), _gamma(nullptr), _epsilon(0) { @@ -85,19 +134,9 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out _kernel.setArg<cl_float>(idx++, _epsilon); // Configure kernel window - Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration); - if(output != nullptr) - { - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - update_window_and_padding(win, input_access, output_access); - output_access.set_valid_region(win, input->info()->valid_region()); - } - else - { - update_window_and_padding(win, input_access); - } - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input->info(), (output == nullptr) ? nullptr : output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, @@ -105,19 +144,8 @@ Error CLBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon) { - ARM_COMPUTE_UNUSED(epsilon); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); - - if(output != nullptr && output->total_size() != 0) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), (output == nullptr) ? nullptr : output->clone().get()).first); return Error{}; } diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp index 5e35c8c1ff..a466fa41b4 100644 --- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp +++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp @@ -40,6 +40,65 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, + ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) +{ + ARM_COMPUTE_UNUSED(overflow_policy); + ARM_COMPUTE_UNUSED(rounding_policy); + + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); + + if(is_data_type_fixed_point(input1->data_type())) + { + // All data types must be all QS8 or all QS16 + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1."); + } + + // Validate in case of configured output + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); + if(is_data_type_fixed_point(input1->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); + } + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + + AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLPixelWiseMultiplicationKernel::CLPixelWiseMultiplicationKernel() : _input1(nullptr), _input2(nullptr), _output(nullptr) { @@ -149,55 +208,16 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I } // Configure kernel window - constexpr unsigned int num_elems_processed_per_iteration = 16; - - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - - AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal input2_access(input2->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, input1_access, input2_access, output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); - output_access.set_valid_region(win, valid_region); - - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { - ARM_COMPUTE_UNUSED(overflow_policy); - ARM_COMPUTE_UNUSED(rounding_policy); - - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); - - if(is_data_type_fixed_point(input1->data_type())) - { - // All data types must be all QS8 or all QS16 - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1."); - } - - // Validate in case of configured output - if((output != nullptr) && (output->total_size() != 0)) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); - if(is_data_type_fixed_point(input1->data_type())) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); - } - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); return Error{}; } |