aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/Validation.h
diff options
context:
space:
mode:
authorsteniu01 <steven.niu@arm.com>2017-08-25 17:18:01 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit3e05e4e85912e745b8555102e1bcef13478d2ceb (patch)
treed4ac2e56bcbcbb2ca73b990deeeb26aa2fa1f73d /tests/validation/Validation.h
parent09e4f98e31a9bb77bebeccd59c70f0715ab2c292 (diff)
downloadComputeLibrary-3e05e4e85912e745b8555102e1bcef13478d2ceb.tar.gz
COMPMID-516 Change the CL CNN validation functions to use relative
tolerance error Change-Id: Iec6347af26ea2a83c911f5fe10e6048e8a2a47ba Reviewed-on: http://mpd-gerrit.cambridge.arm.com/85381 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r--tests/validation/Validation.h36
1 files changed, 26 insertions, 10 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index 49c7d832c1..e70c970cc1 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -77,11 +77,12 @@ private:
};
/** Class reprensenting a relative tolerance value. */
+template <typename T>
class RelativeTolerance
{
public:
/** Underlying type. */
- using value_type = double;
+ using value_type = T;
/* Default constructor.
*
@@ -105,7 +106,7 @@ public:
}
private:
- value_type _value{ 0 };
+ value_type _value{ std::numeric_limits<T>::epsilon() };
};
/** Print AbsoluteTolerance type. */
@@ -118,9 +119,10 @@ inline ::std::ostream &operator<<(::std::ostream &os, const AbsoluteTolerance<T>
}
/** Print RelativeTolerance type. */
-inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance &tolerance)
+template <typename T>
+inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance<T> &tolerance)
{
- os << static_cast<typename RelativeTolerance::value_type>(tolerance);
+ os << static_cast<typename RelativeTolerance<T>::value_type>(tolerance);
return os;
}
@@ -248,24 +250,38 @@ struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<
};
template <typename U>
-struct compare<RelativeTolerance, U> : public compare_base<RelativeTolerance>
+struct compare<RelativeTolerance<U>, U> : public compare_base<RelativeTolerance<U>>
{
- using compare_base<RelativeTolerance>::compare_base;
+ using compare_base<RelativeTolerance<U>>::compare_base;
operator bool() const
{
- if(!std::isfinite(_target) || !std::isfinite(_reference))
+ if(!std::isfinite(this->_target) || !std::isfinite(this->_reference))
{
return false;
}
- else if(_target == _reference)
+ else if(this->_target == this->_reference)
+ {
+ return true;
+ }
+
+ const U epsilon = (std::is_same<half_float::half, typename std::remove_cv<U>::type>::value || (this->_reference == 0)) ? static_cast<U>(0.01) : std::numeric_limits<U>::epsilon();
+
+ if(std::abs(static_cast<double>(this->_reference) - static_cast<double>(this->_target)) <= epsilon)
{
return true;
}
+ else
+ {
+ if(static_cast<double>(this->_reference) == 0.0f) // We have checked whether _reference and _target is closing. If _reference is 0 but not closed to _target, it should return false
+ {
+ return false;
+ }
- const double relative_change = std::abs(static_cast<double>(_target - _reference)) / _reference;
+ const double relative_change = std::abs(static_cast<double>(this->_target) - static_cast<double>(this->_reference)) / this->_reference;
- return relative_change <= _tolerance;
+ return relative_change <= static_cast<U>(this->_tolerance);
+ }
}
};