diff options
-rw-r--r-- | src/core/CL/kernels/CLArithmeticAdditionKernel.cpp | 81 | ||||
-rw-r--r-- | src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp | 79 | ||||
-rw-r--r-- | src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp | 80 | ||||
-rw-r--r-- | src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp | 108 | ||||
-rw-r--r-- | tests/validation/CL/ArithmeticAddition.cpp | 39 | ||||
-rw-r--r-- | tests/validation/CL/ArithmeticSubtraction.cpp | 39 | ||||
-rw-r--r-- | tests/validation/CL/BatchNormalizationLayer.cpp | 40 | ||||
-rw-r--r-- | tests/validation/CL/PixelWiseMultiplication.cpp | 53 |
8 files changed, 310 insertions, 209 deletions
diff --git a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp index a7625f4303..831389e3b6 100644 --- a/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp +++ b/src/core/CL/kernels/CLArithmeticAdditionKernel.cpp @@ -41,6 +41,51 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_UNUSED(policy); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + + // Validate in case of configured output + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + + AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLArithmeticAdditionKernel::CLArithmeticAdditionKernel() : _input1(nullptr), _input2(nullptr), _output(nullptr) { @@ -91,41 +136,15 @@ void CLArithmeticAdditionKernel::configure(const ICLTensor *input1, const ICLTen _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_add", build_opts)); // Configure kernel window - constexpr unsigned int num_elems_processed_per_iteration = 16; - - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - - AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal input2_access(input2->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, input1_access, input2_access, output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); - - output_access.set_valid_region(win, valid_region); - - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) { - ARM_COMPUTE_UNUSED(policy); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); - - // Validate in case of configured output - if((output != nullptr) && (output->total_size() != 0)) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); return Error{}; } diff --git a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp index 47d77ad8a9..5603451ca0 100644 --- a/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp +++ b/src/core/CL/kernels/CLArithmeticSubtractionKernel.cpp @@ -38,6 +38,50 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_UNUSED(policy); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + + // Validate in case of configured output + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLArithmeticSubtractionKernel::CLArithmeticSubtractionKernel() : _input1(nullptr), _input2(nullptr), _output(nullptr) { @@ -84,40 +128,15 @@ void CLArithmeticSubtractionKernel::configure(const ICLTensor *input1, const ICL _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("arithmetic_sub", build_opts)); // Configure kernel window - constexpr unsigned int num_elems_processed_per_iteration = 16; - - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal input2_access(input2->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, input1_access, input2_access, output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); - - output_access.set_valid_region(win, valid_region); - - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) { - ARM_COMPUTE_UNUSED(policy); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); - - // Validate in case of configured output - if((output != nullptr) && (output->total_size() != 0)) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); return Error{}; } diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp index f17091166c..38e367dfb7 100644 --- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp +++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp @@ -37,6 +37,55 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input, const ITensorInfo *output, + const ITensorInfo *mean, const ITensorInfo *var, + const ITensorInfo *beta, const ITensorInfo *gamma, + float epsilon) +{ + ARM_COMPUTE_UNUSED(epsilon); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); + ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); + + if(output != nullptr && output->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) +{ + const unsigned int num_elems_processed_per_iteration = 16 / input->element_size(); + + // Configure kernel window + Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration); + + bool window_changed; + if(output != nullptr) + { + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + window_changed = update_window_and_padding(win, input_access, output_access); + output_access.set_valid_region(win, input->valid_region()); + } + else + { + window_changed = update_window_and_padding(win, input_access); + } + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel() : _input(nullptr), _output(nullptr), _mean(nullptr), _var(nullptr), _beta(nullptr), _gamma(nullptr), _epsilon(0) { @@ -85,19 +134,9 @@ void CLBatchNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *out _kernel.setArg<cl_float>(idx++, _epsilon); // Configure kernel window - Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration); - if(output != nullptr) - { - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - update_window_and_padding(win, input_access, output_access); - output_access.set_valid_region(win, input->info()->valid_region()); - } - else - { - update_window_and_padding(win, input_access); - } - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input->info(), (output == nullptr) ? nullptr : output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, @@ -105,19 +144,8 @@ Error CLBatchNormalizationLayerKernel::validate(const ITensorInfo *input, const const ITensorInfo *beta, const ITensorInfo *gamma, float epsilon) { - ARM_COMPUTE_UNUSED(epsilon); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, mean, var, beta, gamma); - ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(2) != mean->dimension(0)); - - if(output != nullptr && output->total_size() != 0) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mean, var, beta, gamma, epsilon)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), (output == nullptr) ? nullptr : output->clone().get()).first); return Error{}; } diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp index 5e35c8c1ff..a466fa41b4 100644 --- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp +++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp @@ -40,6 +40,65 @@ using namespace arm_compute; +namespace +{ +Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, + ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) +{ + ARM_COMPUTE_UNUSED(overflow_policy); + ARM_COMPUTE_UNUSED(rounding_policy); + + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); + + if(is_data_type_fixed_point(input1->data_type())) + { + // All data types must be all QS8 or all QS16 + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1."); + } + + // Validate in case of configured output + if((output != nullptr) && (output->total_size() != 0)) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); + if(is_data_type_fixed_point(input1->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); + } + } + + return Error{}; +} + +std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + + AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} +} // namespace + CLPixelWiseMultiplicationKernel::CLPixelWiseMultiplicationKernel() : _input1(nullptr), _input2(nullptr), _output(nullptr) { @@ -149,55 +208,16 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I } // Configure kernel window - constexpr unsigned int num_elems_processed_per_iteration = 16; - - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - - AccessWindowHorizontal input1_access(input1->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal input2_access(input2->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, input1_access, input2_access, output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); - output_access.set_valid_region(win, valid_region); - - ICLKernel::configure(win); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); } Error CLPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { - ARM_COMPUTE_UNUSED(overflow_policy); - ARM_COMPUTE_UNUSED(rounding_policy); - - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); - - if(is_data_type_fixed_point(input1->data_type())) - { - // All data types must be all QS8 or all QS16 - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale != 1, "Unsupported scaling factor for QS8/QS16. Scale must be 1."); - } - - // Validate in case of configured output - if((output != nullptr) && (output->total_size() != 0)) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); - if(is_data_type_fixed_point(input1->data_type())) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); - } - } + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); return Error{}; } diff --git a/tests/validation/CL/ArithmeticAddition.cpp b/tests/validation/CL/ArithmeticAddition.cpp index 19ccdaf111..61b3b06d74 100644 --- a/tests/validation/CL/ArithmeticAddition.cpp +++ b/tests/validation/CL/ArithmeticAddition.cpp @@ -64,31 +64,34 @@ TEST_SUITE(ArithmeticAddition) // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( - framework::dataset::make("Input1Info", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Invalid data type combination - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching shapes - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Window shrink + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), }), - framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), + TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S16), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), })), - framework::dataset::make("Expected", { false, false, true, true, true, false })), + framework::dataset::make("Expected", { false, false, true, true, true, true, false })), input1_info, input2_info, output_info, expected) { - ARM_COMPUTE_EXPECT(bool(CLArithmeticAddition::validate(&input1_info, &input2_info, &output_info, ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bool(CLArithmeticAddition::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS); } // clang-format on // *INDENT-ON* diff --git a/tests/validation/CL/ArithmeticSubtraction.cpp b/tests/validation/CL/ArithmeticSubtraction.cpp index a068c8a357..9a290cfe30 100644 --- a/tests/validation/CL/ArithmeticSubtraction.cpp +++ b/tests/validation/CL/ArithmeticSubtraction.cpp @@ -71,31 +71,34 @@ TEST_SUITE(ArithmeticSubtraction) // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( - framework::dataset::make("Input1Info", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Invalid data type combination - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching shapes - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Window shrink + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), }), - framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), + TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S16), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), })), - framework::dataset::make("Expected", { false, false, true, true, true, false })), + framework::dataset::make("Expected", { false, false, true, true, true, true, false })), input1_info, input2_info, output_info, expected) { - ARM_COMPUTE_EXPECT(bool(CLArithmeticSubtraction::validate(&input1_info, &input2_info, &output_info, ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(bool(CLArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS); } // clang-format on // *INDENT-ON* diff --git a/tests/validation/CL/BatchNormalizationLayer.cpp b/tests/validation/CL/BatchNormalizationLayer.cpp index 6884131582..4976c1c1a6 100644 --- a/tests/validation/CL/BatchNormalizationLayer.cpp +++ b/tests/validation/CL/BatchNormalizationLayer.cpp @@ -81,35 +81,41 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Ran // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( - framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching data types - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching data types - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Invalid mean/var/beta/gamma shape - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point position - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Window shrink + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching data types + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching data types + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Invalid mean/var/beta/gamma shape + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point position + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), }), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F16), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F16), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(), })), framework::dataset::make("MVBGInfo",{ TensorInfo(TensorShape(2U), 1, DataType::F32), + TensorInfo(TensorShape(2U), 1, DataType::F32), TensorInfo(TensorShape(2U), 1, DataType::F16), TensorInfo(TensorShape(2U), 1, DataType::F32), TensorInfo(TensorShape(5U), 1, DataType::F32), TensorInfo(TensorShape(2U), 1, DataType::QS8, 2), TensorInfo(TensorShape(2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(2U), 1, DataType::QS8, 2), })), - framework::dataset::make("Expected", { false, true, true, true, true, false})), + framework::dataset::make("Expected", { false, true, true, true, true, true, false, false})), input_info, output_info, mvbg_info, expected) { - auto mean_info = mvbg_info; - auto var_info = mvbg_info; - auto beta_info = mvbg_info; - auto gamma_info = mvbg_info; - bool has_error = bool(CLBatchNormalizationLayer::validate(&input_info, &output_info, &mean_info, &var_info, &beta_info, &gamma_info, 1.f)); + const auto &mean_info = mvbg_info; + const auto &var_info = mvbg_info; + const auto &beta_info = mvbg_info; + const auto &gamma_info = mvbg_info; + bool has_error = bool(CLBatchNormalizationLayer::validate(&input_info.clone()->set_is_resizable(false), (output_info.total_size() == 0) ? nullptr : &output_info.clone()->set_is_resizable(false), &mean_info.clone()->set_is_resizable(false), &var_info.clone()->set_is_resizable(false), &beta_info.clone()->set_is_resizable(false), &gamma_info.clone()->set_is_resizable(false), 1.f)); ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS); } // clang-format on diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp index d49462a7f7..79e30c1754 100644 --- a/tests/validation/CL/PixelWiseMultiplication.cpp +++ b/tests/validation/CL/PixelWiseMultiplication.cpp @@ -93,38 +93,41 @@ TEST_SUITE(PixelWiseMultiplication) // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( - framework::dataset::make("Input1Info", { TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Invalid scale - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Invalid data type combination - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::F32), // Mismatching shapes - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching data type - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), // Invalid scale + framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Window shrink + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid scale + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching data type + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Invalid scale }), - framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS16, 2), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), + TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS16, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(30U, 11U, 2U), 1, DataType::F32), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 3), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), - TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), + TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), })), - framework::dataset::make("Scale",{ 2.f, 2.f, -1.f, 1.f, 1.f, 1.f, 1.f, 3.f})), - framework::dataset::make("Expected", { false, false, true, true, true, true, true, true })), + framework::dataset::make("Scale",{ 2.f, 2.f, 2.f, -1.f, 1.f, 1.f, 1.f, 1.f, 3.f})), + framework::dataset::make("Expected", { false, false, true, true, true, true, true, true, true })), input1_info, input2_info, output_info, scale, expected) { - bool has_error = bool(CLPixelWiseMultiplication::validate(&input1_info, &input2_info, &output_info, scale, ConvertPolicy::WRAP, RoundingPolicy::TO_ZERO)); + bool has_error = bool(CLPixelWiseMultiplication::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), scale, ConvertPolicy::WRAP, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_EXPECT(has_error == expected, framework::LogLevel::ERRORS); } // clang-format on |