diff options
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r-- | tests/validation/Validation.h | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index f1ce0fecc7..289aca4d08 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -45,6 +45,17 @@ namespace test { namespace validation { +namespace +{ +// Compare if 2 values are both infinities and if they are "equal" (has the same sign) +template <typename T> +inline bool are_equal_infs(T val0, T val1) +{ + const auto same_sign = support::cpp11::signbit(val0) == support::cpp11::signbit(val1); + return (!support::cpp11::isfinite(val0)) && (!support::cpp11::isfinite(val1)) && same_sign; +} +} // namespace + /** Class reprensenting an absolute tolerance value. */ template <typename T> class AbsoluteTolerance @@ -140,7 +151,7 @@ bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &d { ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN); - if(data_layout == DataLayout::NCHW) + if(data_layout != DataLayout::NHWC) { if(dimensions1.num_dimensions() != dimensions2.num_dimensions()) { @@ -296,9 +307,9 @@ struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>> /** Perform comparison */ operator bool() const { - if(!support::cpp11::isfinite(this->_target) || !support::cpp11::isfinite(this->_reference)) + if(are_equal_infs(this->_target, this->_reference)) { - return false; + return true; } else if(this->_target == this->_reference) { @@ -322,9 +333,9 @@ struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>> /** Perform comparison */ operator bool() const { - if(!support::cpp11::isfinite(this->_target) || !support::cpp11::isfinite(this->_reference)) + if(are_equal_infs(this->_target, this->_reference)) { - return false; + return true; } else if(this->_target == this->_reference) { @@ -494,9 +505,9 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, co // check for wrapping if(!equal) { - if(!support::cpp11::isfinite(target_value) || !support::cpp11::isfinite(reference_value)) + if(are_equal_infs(target_value, reference_value)) { - equal = false; + equal = true; } else { |