diff options
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r-- | tests/validation/Validation.h | 33 |
1 files changed, 20 insertions, 13 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index 7a01085514..4a96dd34b6 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -236,15 +236,14 @@ void validate_keypoints(T target_first, T target_last, U reference_first, U refe template <typename T> struct compare_base { - compare_base(typename T::value_type target, typename T::value_type reference, T tolerance = T(0), bool wrap_range = false) - : _target{ target }, _reference{ reference }, _tolerance{ tolerance }, _wrap_range{ wrap_range } + compare_base(typename T::value_type target, typename T::value_type reference, T tolerance = T(0)) + : _target{ target }, _reference{ reference }, _tolerance{ tolerance } { } typename T::value_type _target{}; typename T::value_type _reference{}; T _tolerance{}; - bool _wrap_range{}; }; template <typename T> @@ -268,12 +267,6 @@ struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>> using comparison_type = typename std::conditional<std::is_integral<U>::value, int64_t, U>::type; - if(this->_wrap_range) - { - const comparison_type abs_difference(std::abs(static_cast<comparison_type>(this->_target)) - std::abs(static_cast<comparison_type>(this->_reference))); - return abs_difference <= static_cast<comparison_type>(this->_tolerance); - } - const comparison_type abs_difference(std::abs(static_cast<comparison_type>(this->_target) - static_cast<comparison_type>(this->_reference))); return abs_difference <= static_cast<comparison_type>(this->_tolerance); @@ -323,7 +316,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U toler validate(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number); } -template <typename T, typename U> +template <typename T, typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type> void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value, float tolerance_number) { // Validate with valid region covering the entire shape @@ -391,7 +384,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V } } -template <typename T, typename U> +template <typename T, typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type> void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value, float tolerance_number) { int64_t num_mismatches = 0; @@ -426,9 +419,23 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, co bool equal = compare<U>(target_value, reference_value, tolerance_value); + // check for wrapping if(!equal) { - equal = compare<U>(target_value, reference_value, tolerance_value, true); + if(!support::cpp11::isfinite(target_value) || !support::cpp11::isfinite(reference_value)) + { + equal = false; + } + else + { + using limits_type = typename std::make_unsigned<T>::type; + + uint64_t max = std::numeric_limits<limits_type>::max(); + uint64_t abs_sum = std::abs(static_cast<int64_t>(target_value)) + std::abs(static_cast<int64_t>(reference_value)); + uint64_t wrap_difference = max - abs_sum; + + equal = wrap_difference < static_cast<uint64_t>(tolerance_value); + } } if(!equal) @@ -437,7 +444,7 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, co ARM_COMPUTE_TEST_INFO("channel = " << c); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target_value)); ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference_value)); - ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance_value))); + ARM_COMPUTE_TEST_INFO("wrap_tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance_value))); framework::ARM_COMPUTE_PRINT_INFO(); ++num_mismatches; |