From 9c450cc0e0b2e7060fa0a74a5196906bc28d0625 Mon Sep 17 00:00:00 2001 From: John Richardson Date: Wed, 22 Nov 2017 12:00:41 +0000 Subject: COMPMID-695: Update Phase and Validation Wrapping Simplify Phase reference implementation so that its results are more inline with the CL implementation (note: NEON uses a fast arctan approximation). Modify validate_wrap function to limit use to Integer types only. Change-Id: Ie4222568a8ef2587cab8e6d478745c5d0ded3d57 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/110192 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com Reviewed-by: Anthony Barbier Reviewed-by: Gian Marco Iodice --- tests/validation/Validation.h | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) (limited to 'tests/validation/Validation.h') diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index 7a01085514..4a96dd34b6 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -236,15 +236,14 @@ void validate_keypoints(T target_first, T target_last, U reference_first, U refe template struct compare_base { - compare_base(typename T::value_type target, typename T::value_type reference, T tolerance = T(0), bool wrap_range = false) - : _target{ target }, _reference{ reference }, _tolerance{ tolerance }, _wrap_range{ wrap_range } + compare_base(typename T::value_type target, typename T::value_type reference, T tolerance = T(0)) + : _target{ target }, _reference{ reference }, _tolerance{ tolerance } { } typename T::value_type _target{}; typename T::value_type _reference{}; T _tolerance{}; - bool _wrap_range{}; }; template @@ -268,12 +267,6 @@ struct compare> : public compare_base> using comparison_type = typename std::conditional::value, int64_t, U>::type; - if(this->_wrap_range) - { - const comparison_type abs_difference(std::abs(static_cast(this->_target)) - std::abs(static_cast(this->_reference))); - return abs_difference <= static_cast(this->_tolerance); - } - const comparison_type abs_difference(std::abs(static_cast(this->_target) - static_cast(this->_reference))); return abs_difference <= static_cast(this->_tolerance); @@ -323,7 +316,7 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, U toler validate(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number); } -template +template ::value>::type> void validate_wrap(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value, float tolerance_number) { // Validate with valid region covering the entire shape @@ -391,7 +384,7 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const V } } -template +template ::value>::type> void validate_wrap(const IAccessor &tensor, const SimpleTensor &reference, const ValidRegion &valid_region, U tolerance_value, float tolerance_number) { int64_t num_mismatches = 0; @@ -426,9 +419,23 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor &reference, co bool equal = compare(target_value, reference_value, tolerance_value); + // check for wrapping if(!equal) { - equal = compare(target_value, reference_value, tolerance_value, true); + if(!support::cpp11::isfinite(target_value) || !support::cpp11::isfinite(reference_value)) + { + equal = false; + } + else + { + using limits_type = typename std::make_unsigned::type; + + uint64_t max = std::numeric_limits::max(); + uint64_t abs_sum = std::abs(static_cast(target_value)) + std::abs(static_cast(reference_value)); + uint64_t wrap_difference = max - abs_sum; + + equal = wrap_difference < static_cast(tolerance_value); + } } if(!equal) @@ -437,7 +444,7 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor &reference, co ARM_COMPUTE_TEST_INFO("channel = " << c); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target_value)); ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference_value)); - ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast(tolerance_value))); + ARM_COMPUTE_TEST_INFO("wrap_tolerance = " << std::setprecision(5) << framework::make_printable(static_cast(tolerance_value))); framework::ARM_COMPUTE_PRINT_INFO(); ++num_mismatches; -- cgit v1.2.1