From 6106a4de410e7cc59515dd889e159bee7aa45d35 Mon Sep 17 00:00:00 2001 From: Moritz Pflanzer Date: Wed, 2 Aug 2017 09:42:27 +0100 Subject: COMPMID-415: Use absolute and relative tolerance Change-Id: Ib779fa307e05fa67172ddaf521239b4c746debc8 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/82229 Reviewed-by: Anthony Barbier Tested-by: Kaizen --- tests/validation_new/Validation.h | 174 ++++++++++++++++++++++++++++++-------- 1 file changed, 139 insertions(+), 35 deletions(-) (limited to 'tests/validation_new/Validation.h') diff --git a/tests/validation_new/Validation.h b/tests/validation_new/Validation.h index 91b17145be..b21d12932a 100644 --- a/tests/validation_new/Validation.h +++ b/tests/validation_new/Validation.h @@ -43,6 +43,88 @@ namespace test { namespace validation { +/** Class reprensenting an absolute tolerance value. */ +template +class AbsoluteTolerance +{ +public: + /** Underlying type. */ + using value_type = T; + + /* Default constructor. + * + * Initialises the tolerance to 0. + */ + AbsoluteTolerance() = default; + + /** Constructor. + * + * @param[in] value Absolute tolerance value. + */ + explicit constexpr AbsoluteTolerance(T value) + : _value{ value } + { + } + + /** Implicit conversion to the underlying type. */ + constexpr operator T() const + { + return _value; + } + +private: + T _value{ std::numeric_limits::epsilon() }; +}; + +/** Class reprensenting a relative tolerance value. */ +class RelativeTolerance +{ +public: + /** Underlying type. */ + using value_type = double; + + /* Default constructor. + * + * Initialises the tolerance to 0. + */ + RelativeTolerance() = default; + + /** Constructor. + * + * @param[in] value Relative tolerance value. + */ + explicit constexpr RelativeTolerance(value_type value) + : _value{ value } + { + } + + /** Implicit conversion to the underlying type. */ + constexpr operator value_type() const + { + return _value; + } + +private: + value_type _value{ 0 }; +}; + +/** Print AbsoluteTolerance type. */ +template +inline ::std::ostream &operator<<(::std::ostream &os, const AbsoluteTolerance &tolerance) +{ + os << static_cast::value_type>(tolerance); + + return os; +} + +/** Print RelativeTolerance type. */ +inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance &tolerance) +{ + os << static_cast(tolerance); + + return os; +} + template bool compare_dimensions(const Dimensions &dimensions1, const Dimensions &dimensions2) { @@ -86,8 +168,8 @@ void validate(const arm_compute::PaddingSize &padding, const arm_compute::Paddin * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by * other test cases. */ -template -void validate(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value = U(0), float tolerance_number = 0.f); +template > +void validate(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value = U(), float tolerance_number = 0.f); /** Validate tensors with valid region. * @@ -99,8 +181,8 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, U toler * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by * other test cases. */ -template -void validate(const IAccessor &tensor, const SimpleTensor &reference, const ValidRegion &valid_region, U tolerance_value = U(0), float tolerance_number = 0.f); +template > +void validate(const IAccessor &tensor, const SimpleTensor &reference, const ValidRegion &valid_region, U tolerance_value = U(), float tolerance_number = 0.f); /** Validate tensors against constant value. * @@ -126,42 +208,66 @@ void validate(std::vector classified_labels, std::vector -void validate(T target, T ref, U tolerance_abs_error = std::numeric_limits::epsilon(), double tolerance_relative_error = 0.0001f); +template +void validate(T target, T reference, U tolerance = AbsoluteTolerance()); -template -bool is_equal(T target, T ref, U max_absolute_error = std::numeric_limits::epsilon(), double max_relative_error = 0.0001f) +template +struct compare_base { - if(!std::isfinite(target) || !std::isfinite(ref)) + compare_base(typename T::value_type target, typename T::value_type reference, T tolerance = T(0)) + : _target{ target }, _reference{ reference }, _tolerance{ tolerance } { - return false; } - // No need further check if they are equal - if(ref == target) - { - return true; - } + typename T::value_type _target{}; + typename T::value_type _reference{}; + T _tolerance{}; +}; - // Need this check for the situation when the two values close to zero but have different sign - if(std::abs(std::abs(ref) - std::abs(target)) <= max_absolute_error) - { - return true; - } +template +struct compare; - double relative_error = 0; +template +struct compare, U> : public compare_base> +{ + using compare_base>::compare_base; - if(std::abs(target) > std::abs(ref)) + operator bool() { - relative_error = std::abs(static_cast(target - ref) / target); + if(!std::isfinite(this->_target) || !std::isfinite(this->_reference)) + { + return false; + } + else if(this->_target == this->_reference) + { + return true; + } + + return static_cast(std::abs(this->_target - this->_reference)) <= static_cast(this->_tolerance); } - else +}; + +template +struct compare : public compare_base +{ + using compare_base::compare_base; + + operator bool() { - relative_error = std::abs(static_cast(ref - target) / ref); - } + if(!std::isfinite(_target) || !std::isfinite(_reference)) + { + return false; + } + else if(_target == _reference) + { + return true; + } - return relative_error <= max_relative_error; -} + const double relative_change = std::abs(static_cast(_target - _reference)) / _reference; + + return relative_change <= _tolerance; + } +}; template void validate(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value, float tolerance_number) @@ -198,7 +304,7 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const V const T &target_value = reinterpret_cast(tensor(id))[c]; const T &reference_value = reinterpret_cast(reference(id))[c]; - if(!is_equal(target_value, reference_value, tolerance_value)) + if(!compare(target_value, reference_value, tolerance_value)) { ARM_COMPUTE_TEST_INFO("id = " << id); ARM_COMPUTE_TEST_INFO("channel = " << c); @@ -227,14 +333,12 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const V } template -void validate(T target, T ref, U tolerance_abs_error, double tolerance_relative_error) +void validate(T target, T reference, U tolerance) { - const bool equal = is_equal(target, ref, tolerance_abs_error, tolerance_relative_error); - - ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << ref); + ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << reference); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << target); - ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance_abs_error); - ARM_COMPUTE_EXPECT(equal, framework::LogLevel::ERRORS); + ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance); + ARM_COMPUTE_EXPECT((compare(target, reference, tolerance)), framework::LogLevel::ERRORS); } } // namespace validation } // namespace test -- cgit v1.2.1