diff options
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r-- | tests/validation/Validation.h | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index e70c970cc1..6bc42a4ed6 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -226,11 +226,11 @@ struct compare_base T _tolerance{}; }; -template <typename T, typename U> +template <typename T> struct compare; template <typename U> -struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<U>> +struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>> { using compare_base<AbsoluteTolerance<U>>::compare_base; @@ -245,12 +245,16 @@ struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance< return true; } - return static_cast<U>(std::abs(this->_target - this->_reference)) <= static_cast<U>(this->_tolerance); + using comparison_type = typename std::conditional<std::is_integral<U>::value, int64_t, U>::type; + + const comparison_type abs_difference(std::abs(static_cast<comparison_type>(this->_target) - static_cast<comparison_type>(this->_reference))); + + return abs_difference <= static_cast<comparison_type>(this->_tolerance); } }; template <typename U> -struct compare<RelativeTolerance<U>, U> : public compare_base<RelativeTolerance<U>> +struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>> { using compare_base<RelativeTolerance<U>>::compare_base; @@ -325,7 +329,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V const T &target_value = reinterpret_cast<const T *>(tensor(id))[c]; const T &reference_value = reinterpret_cast<const T *>(reference(id))[c]; - if(!compare<U, typename U::value_type>(target_value, reference_value, tolerance_value)) + if(!compare<U>(target_value, reference_value, tolerance_value)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << c); @@ -359,7 +363,7 @@ void validate(T target, T reference, U tolerance) ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference)); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target)); ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance))); - ARM_COMPUTE_EXPECT((compare<U, typename U::value_type>(target, reference, tolerance)), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS); } } // namespace validation } // namespace test |