From 754e9526a7caf50876c2db9563dc72f096093b34 Mon Sep 17 00:00:00 2001 From: Ioan-Cristian Szabo Date: Tue, 28 Nov 2017 18:29:43 +0000 Subject: COMPMID-617: Add validate support for NEON PixelWiseMultiplication Change-Id: Ie81a4d667146315fed7668cf2ca752d3bf49b0ab Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111013 Reviewed-by: Anthony Barbier Reviewed-by: Michalis Spyrou Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com --- .../kernels/NEPixelWiseMultiplicationKernel.cpp | 126 +++++++++++++-------- 1 file changed, 79 insertions(+), 47 deletions(-) (limited to 'src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp') diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp index a2f3cffdf3..d765966960 100644 --- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp +++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp @@ -54,6 +54,68 @@ const float scale255_constant = 1.f / 255.f; const float32x4_t scale255_constant_f32q = vdupq_n_f32(scale255_constant); const float32x4_t positive_round_f32q = vdupq_n_f32(0.5f); +inline Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) +{ + ARM_COMPUTE_UNUSED(overflow_policy); + ARM_COMPUTE_UNUSED(rounding_policy); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), + "Output can only be U8 if both inputs are U8"); + + if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type())) + { + // Check that all data types are the same and all fixed-point positions are the same + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output); + // Check if scale is representable in fixed-point with the provided settings + ARM_COMPUTE_RETURN_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(scale, input1); + } + + if(std::abs(scale - scale255_constant) < 0.00001f) + { + ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO); + + int exponent = 0; + const float normalized_mantissa = std::frexp(scale, &exponent); + + // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15 + // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -1 <= e <= 14 + // Moreover, it will be negative as we deal with 1/2^n + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)), "Scale value not supported (Should be 1/(2^n) or 1/255"); + } + + return Error{}; +} + +inline std::pair validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + // Configure kernel window + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, + AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration), + AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration), + output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} + /* Scales a given vector by 1/255. * * @note This does not work for all cases. e.g. for float of 0.49999999999999994 and large floats. @@ -443,6 +505,7 @@ NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel() void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { + ARM_COMPUTE_UNUSED(rounding_policy); ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); // Auto initialize output if not initialized @@ -468,19 +531,7 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe } } - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type())) - { - // Check that all data types are the same and all fixed-point positions are the same - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output); - // Check if scale is representable in fixed-point with the provided settings - ARM_COMPUTE_ERROR_ON_VALUE_NOT_REPRESENTABLE_IN_FIXED_POINT(scale, input1); - } + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), scale, overflow_policy, rounding_policy)); _input1 = input1; _input2 = input2; @@ -495,32 +546,17 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe // Check and validate scaling factor if(std::abs(scale - scale255_constant) < 0.00001f) { - ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_NEAREST_UP && rounding_policy != RoundingPolicy::TO_NEAREST_EVEN); - ARM_COMPUTE_UNUSED(rounding_policy); - is_scale_255 = true; } else { - ARM_COMPUTE_ERROR_ON(rounding_policy != RoundingPolicy::TO_ZERO); - ARM_COMPUTE_UNUSED(rounding_policy); + int exponent = 0; - int exponent = 0; - const float normalized_mantissa = std::frexp(scale, &exponent); + std::frexp(scale, &exponent); - // Use int scaling if factor is equal to 1/2^n for 0 <= n <= 15 - // frexp returns 0.5 as mantissa which means that the exponent will be in the range of -1 <= e <= 14 - // Moreover, it will be negative as we deal with 1/2^n - if((normalized_mantissa == 0.5f) && (-14 <= exponent) && (exponent <= 1)) - { - // Store the positive exponent. We know that we compute 1/2^n - // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5 - _scale_exponent = std::abs(exponent - 1); - } - else - { - ARM_COMPUTE_ERROR("Scale value not supported (Should be 1/(2^n) or 1/255"); - } + // Store the positive exponent. We know that we compute 1/2^n + // Additionally we need to subtract 1 to compensate that frexp used a mantissa of 0.5 + _scale_exponent = std::abs(exponent - 1); } const DataType dt_input1 = input1->info()->data_type(); @@ -620,23 +656,19 @@ void NEPixelWiseMultiplicationKernel::configure(const ITensor *input1, const ITe ARM_COMPUTE_ERROR("You called with the wrong img formats"); } - constexpr unsigned int num_elems_processed_per_iteration = 16; - // Configure kernel window - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, - AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration), - AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration), - output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + INEKernel::configure(win_config.second); +} - output_access.set_valid_region(win, valid_region); +Error NEPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, + RoundingPolicy rounding_policy) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); - INEKernel::configure(win); + return Error{}; } void NEPixelWiseMultiplicationKernel::run(const Window &window, const ThreadInfo &info) -- cgit v1.2.1