diff options
Diffstat (limited to 'tests/validation_new/Validation.h')
-rw-r--r-- | tests/validation_new/Validation.h | 174 |
1 files changed, 139 insertions, 35 deletions
diff --git a/tests/validation_new/Validation.h b/tests/validation_new/Validation.h index 91b17145be..b21d12932a 100644 --- a/tests/validation_new/Validation.h +++ b/tests/validation_new/Validation.h @@ -43,6 +43,88 @@ namespace test { namespace validation { +/** Class reprensenting an absolute tolerance value. */ +template <typename T> +class AbsoluteTolerance +{ +public: + /** Underlying type. */ + using value_type = T; + + /* Default constructor. + * + * Initialises the tolerance to 0. + */ + AbsoluteTolerance() = default; + + /** Constructor. + * + * @param[in] value Absolute tolerance value. + */ + explicit constexpr AbsoluteTolerance(T value) + : _value{ value } + { + } + + /** Implicit conversion to the underlying type. */ + constexpr operator T() const + { + return _value; + } + +private: + T _value{ std::numeric_limits<T>::epsilon() }; +}; + +/** Class reprensenting a relative tolerance value. */ +class RelativeTolerance +{ +public: + /** Underlying type. */ + using value_type = double; + + /* Default constructor. + * + * Initialises the tolerance to 0. + */ + RelativeTolerance() = default; + + /** Constructor. + * + * @param[in] value Relative tolerance value. + */ + explicit constexpr RelativeTolerance(value_type value) + : _value{ value } + { + } + + /** Implicit conversion to the underlying type. */ + constexpr operator value_type() const + { + return _value; + } + +private: + value_type _value{ 0 }; +}; + +/** Print AbsoluteTolerance type. */ +template <typename T> +inline ::std::ostream &operator<<(::std::ostream &os, const AbsoluteTolerance<T> &tolerance) +{ + os << static_cast<typename AbsoluteTolerance<T>::value_type>(tolerance); + + return os; +} + +/** Print RelativeTolerance type. */ +inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance &tolerance) +{ + os << static_cast<typename RelativeTolerance::value_type>(tolerance); + + return os; +} + template <typename T> bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &dimensions2) { @@ -86,8 +168,8 @@ void validate(const arm_compute::PaddingSize &padding, const arm_compute::Paddin * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by * other test cases. */ -template <typename T, typename U = T> -void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(0), float tolerance_number = 0.f); +template <typename T, typename U = AbsoluteTolerance<T>> +void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(), float tolerance_number = 0.f); /** Validate tensors with valid region. * @@ -99,8 +181,8 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U toler * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by * other test cases. */ -template <typename T, typename U = T> -void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(0), float tolerance_number = 0.f); +template <typename T, typename U = AbsoluteTolerance<T>> +void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(), float tolerance_number = 0.f); /** Validate tensors against constant value. * @@ -126,42 +208,66 @@ void validate(std::vector<unsigned int> classified_labels, std::vector<unsigned * * - All values should match */ -template <typename T, typename U = T> -void validate(T target, T ref, U tolerance_abs_error = std::numeric_limits<T>::epsilon(), double tolerance_relative_error = 0.0001f); +template <typename T, typename U> +void validate(T target, T reference, U tolerance = AbsoluteTolerance<T>()); -template <typename T, typename U = T> -bool is_equal(T target, T ref, U max_absolute_error = std::numeric_limits<T>::epsilon(), double max_relative_error = 0.0001f) +template <typename T> +struct compare_base { - if(!std::isfinite(target) || !std::isfinite(ref)) + compare_base(typename T::value_type target, typename T::value_type reference, T tolerance = T(0)) + : _target{ target }, _reference{ reference }, _tolerance{ tolerance } { - return false; } - // No need further check if they are equal - if(ref == target) - { - return true; - } + typename T::value_type _target{}; + typename T::value_type _reference{}; + T _tolerance{}; +}; - // Need this check for the situation when the two values close to zero but have different sign - if(std::abs(std::abs(ref) - std::abs(target)) <= max_absolute_error) - { - return true; - } +template <typename T, typename U> +struct compare; - double relative_error = 0; +template <typename U> +struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<U>> +{ + using compare_base<AbsoluteTolerance<U>>::compare_base; - if(std::abs(target) > std::abs(ref)) + operator bool() { - relative_error = std::abs(static_cast<double>(target - ref) / target); + if(!std::isfinite(this->_target) || !std::isfinite(this->_reference)) + { + return false; + } + else if(this->_target == this->_reference) + { + return true; + } + + return static_cast<U>(std::abs(this->_target - this->_reference)) <= static_cast<U>(this->_tolerance); } - else +}; + +template <typename U> +struct compare<RelativeTolerance, U> : public compare_base<RelativeTolerance> +{ + using compare_base<RelativeTolerance>::compare_base; + + operator bool() { - relative_error = std::abs(static_cast<double>(ref - target) / ref); - } + if(!std::isfinite(_target) || !std::isfinite(_reference)) + { + return false; + } + else if(_target == _reference) + { + return true; + } - return relative_error <= max_relative_error; -} + const double relative_change = std::abs(static_cast<double>(_target - _reference)) / _reference; + + return relative_change <= _tolerance; + } +}; template <typename T, typename U> void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value, float tolerance_number) @@ -198,7 +304,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V const T &target_value = reinterpret_cast<const T *>(tensor(id))[c]; const T &reference_value = reinterpret_cast<const T *>(reference(id))[c]; - if(!is_equal(target_value, reference_value, tolerance_value)) + if(!compare<U, typename U::value_type>(target_value, reference_value, tolerance_value)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << c); @@ -227,14 +333,12 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V } template <typename T, typename U> -void validate(T target, T ref, U tolerance_abs_error, double tolerance_relative_error) +void validate(T target, T reference, U tolerance) { - const bool equal = is_equal(target, ref, tolerance_abs_error, tolerance_relative_error); - - ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << ref); + ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << reference); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << target); - ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance_abs_error); - ARM_COMPUTE_EXPECT(equal, framework::LogLevel::ERRORS); + ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance); + ARM_COMPUTE_EXPECT((compare<U, typename U::value_type>(target, reference, tolerance)), framework::LogLevel::ERRORS); } } // namespace validation } // namespace test |