aboutsummaryrefslogtreecommitdiff
path: root/tests/validation_new/Validation.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation_new/Validation.h')
-rw-r--r--tests/validation_new/Validation.h174
1 files changed, 139 insertions, 35 deletions
diff --git a/tests/validation_new/Validation.h b/tests/validation_new/Validation.h
index 91b17145be..b21d12932a 100644
--- a/tests/validation_new/Validation.h
+++ b/tests/validation_new/Validation.h
@@ -43,6 +43,88 @@ namespace test
{
namespace validation
{
+/** Class reprensenting an absolute tolerance value. */
+template <typename T>
+class AbsoluteTolerance
+{
+public:
+ /** Underlying type. */
+ using value_type = T;
+
+ /* Default constructor.
+ *
+ * Initialises the tolerance to 0.
+ */
+ AbsoluteTolerance() = default;
+
+ /** Constructor.
+ *
+ * @param[in] value Absolute tolerance value.
+ */
+ explicit constexpr AbsoluteTolerance(T value)
+ : _value{ value }
+ {
+ }
+
+ /** Implicit conversion to the underlying type. */
+ constexpr operator T() const
+ {
+ return _value;
+ }
+
+private:
+ T _value{ std::numeric_limits<T>::epsilon() };
+};
+
+/** Class reprensenting a relative tolerance value. */
+class RelativeTolerance
+{
+public:
+ /** Underlying type. */
+ using value_type = double;
+
+ /* Default constructor.
+ *
+ * Initialises the tolerance to 0.
+ */
+ RelativeTolerance() = default;
+
+ /** Constructor.
+ *
+ * @param[in] value Relative tolerance value.
+ */
+ explicit constexpr RelativeTolerance(value_type value)
+ : _value{ value }
+ {
+ }
+
+ /** Implicit conversion to the underlying type. */
+ constexpr operator value_type() const
+ {
+ return _value;
+ }
+
+private:
+ value_type _value{ 0 };
+};
+
+/** Print AbsoluteTolerance type. */
+template <typename T>
+inline ::std::ostream &operator<<(::std::ostream &os, const AbsoluteTolerance<T> &tolerance)
+{
+ os << static_cast<typename AbsoluteTolerance<T>::value_type>(tolerance);
+
+ return os;
+}
+
+/** Print RelativeTolerance type. */
+inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance &tolerance)
+{
+ os << static_cast<typename RelativeTolerance::value_type>(tolerance);
+
+ return os;
+}
+
template <typename T>
bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &dimensions2)
{
@@ -86,8 +168,8 @@ void validate(const arm_compute::PaddingSize &padding, const arm_compute::Paddin
* reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
* other test cases.
*/
-template <typename T, typename U = T>
-void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(0), float tolerance_number = 0.f);
+template <typename T, typename U = AbsoluteTolerance<T>>
+void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value = U(), float tolerance_number = 0.f);
/** Validate tensors with valid region.
*
@@ -99,8 +181,8 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U toler
* reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by
* other test cases.
*/
-template <typename T, typename U = T>
-void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(0), float tolerance_number = 0.f);
+template <typename T, typename U = AbsoluteTolerance<T>>
+void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const ValidRegion &valid_region, U tolerance_value = U(), float tolerance_number = 0.f);
/** Validate tensors against constant value.
*
@@ -126,42 +208,66 @@ void validate(std::vector<unsigned int> classified_labels, std::vector<unsigned
*
* - All values should match
*/
-template <typename T, typename U = T>
-void validate(T target, T ref, U tolerance_abs_error = std::numeric_limits<T>::epsilon(), double tolerance_relative_error = 0.0001f);
+template <typename T, typename U>
+void validate(T target, T reference, U tolerance = AbsoluteTolerance<T>());
-template <typename T, typename U = T>
-bool is_equal(T target, T ref, U max_absolute_error = std::numeric_limits<T>::epsilon(), double max_relative_error = 0.0001f)
+template <typename T>
+struct compare_base
{
- if(!std::isfinite(target) || !std::isfinite(ref))
+ compare_base(typename T::value_type target, typename T::value_type reference, T tolerance = T(0))
+ : _target{ target }, _reference{ reference }, _tolerance{ tolerance }
{
- return false;
}
- // No need further check if they are equal
- if(ref == target)
- {
- return true;
- }
+ typename T::value_type _target{};
+ typename T::value_type _reference{};
+ T _tolerance{};
+};
- // Need this check for the situation when the two values close to zero but have different sign
- if(std::abs(std::abs(ref) - std::abs(target)) <= max_absolute_error)
- {
- return true;
- }
+template <typename T, typename U>
+struct compare;
- double relative_error = 0;
+template <typename U>
+struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<U>>
+{
+ using compare_base<AbsoluteTolerance<U>>::compare_base;
- if(std::abs(target) > std::abs(ref))
+ operator bool()
{
- relative_error = std::abs(static_cast<double>(target - ref) / target);
+ if(!std::isfinite(this->_target) || !std::isfinite(this->_reference))
+ {
+ return false;
+ }
+ else if(this->_target == this->_reference)
+ {
+ return true;
+ }
+
+ return static_cast<U>(std::abs(this->_target - this->_reference)) <= static_cast<U>(this->_tolerance);
}
- else
+};
+
+template <typename U>
+struct compare<RelativeTolerance, U> : public compare_base<RelativeTolerance>
+{
+ using compare_base<RelativeTolerance>::compare_base;
+
+ operator bool()
{
- relative_error = std::abs(static_cast<double>(ref - target) / ref);
- }
+ if(!std::isfinite(_target) || !std::isfinite(_reference))
+ {
+ return false;
+ }
+ else if(_target == _reference)
+ {
+ return true;
+ }
- return relative_error <= max_relative_error;
-}
+ const double relative_change = std::abs(static_cast<double>(_target - _reference)) / _reference;
+
+ return relative_change <= _tolerance;
+ }
+};
template <typename T, typename U>
void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value, float tolerance_number)
@@ -198,7 +304,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V
const T &target_value = reinterpret_cast<const T *>(tensor(id))[c];
const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
- if(!is_equal(target_value, reference_value, tolerance_value))
+ if(!compare<U, typename U::value_type>(target_value, reference_value, tolerance_value))
{
ARM_COMPUTE_TEST_INFO("id = " << id);
ARM_COMPUTE_TEST_INFO("channel = " << c);
@@ -227,14 +333,12 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V
}
template <typename T, typename U>
-void validate(T target, T ref, U tolerance_abs_error, double tolerance_relative_error)
+void validate(T target, T reference, U tolerance)
{
- const bool equal = is_equal(target, ref, tolerance_abs_error, tolerance_relative_error);
-
- ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << ref);
+ ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << reference);
ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << target);
- ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance_abs_error);
- ARM_COMPUTE_EXPECT(equal, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << tolerance);
+ ARM_COMPUTE_EXPECT((compare<U, typename U::value_type>(target, reference, tolerance)), framework::LogLevel::ERRORS);
}
} // namespace validation
} // namespace test