diff options
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r-- | tests/validation/Validation.h | 87 |
1 files changed, 84 insertions, 3 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index b6e7b8e82b..5e5dab0040 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -25,6 +25,7 @@ #define __ARM_COMPUTE_TEST_VALIDATION_H__ #include "arm_compute/core/FixedPoint.h" +#include "arm_compute/core/IArray.h" #include "arm_compute/core/Types.h" #include "tests/IAccessor.h" #include "tests/SimpleTensor.h" @@ -212,7 +213,11 @@ void validate(std::vector<unsigned int> classified_labels, std::vector<unsigned * - All values should match */ template <typename T, typename U = AbsoluteTolerance<T>> -void validate(T target, T reference, U tolerance = AbsoluteTolerance<T>()); +bool validate(T target, T reference, U tolerance = AbsoluteTolerance<T>()); + +/** Validate key points. */ +template <typename T, typename U, typename V = AbsoluteTolerance<float>> +void validate_keypoints(T target_first, T target_last, U reference_first, U reference_last, V tolerance = AbsoluteTolerance<float>()); template <typename T> struct compare_base @@ -358,13 +363,89 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V } } +/** Check which keypoints from [first1, last1) are missing in [first2, last2) */ +template <typename T, typename U, typename V> +std::pair<int64_t, int64_t> compare_keypoints(T first1, T last1, U first2, U last2, V tolerance) +{ + int64_t num_missing = 0; + int64_t num_mismatches = 0; + + while(first1 != last1) + { + const auto point = std::find_if(first2, last2, [&](KeyPoint point) + { + return point.x == first1->x && point.y == first1->y; + }); + + if(point == last2) + { + ++num_missing; + ARM_COMPUTE_TEST_INFO("keypoint1 = " << *first1) + ARM_COMPUTE_EXPECT_FAIL("Key point not found", framework::LogLevel::DEBUG); + } + else if(!validate(point->tracking_status, first1->tracking_status) || !validate(point->strength, first1->strength, tolerance) || !validate(point->scale, first1->scale) + || !validate(point->orientation, first1->orientation) || !validate(point->error, first1->error)) + { + ++num_mismatches; + ARM_COMPUTE_TEST_INFO("keypoint1 = " << *first1) + ARM_COMPUTE_TEST_INFO("keypoint2 = " << *point) + ARM_COMPUTE_EXPECT_FAIL("Mismatching keypoint", framework::LogLevel::DEBUG); + } + + ++first1; + } + + return std::make_pair(num_missing, num_mismatches); +} + +template <typename T, typename U, typename V> +void validate_keypoints(T target_first, T target_last, U reference_first, U reference_last, V tolerance) +{ + const int64_t num_elements_target = std::distance(target_first, target_last); + const int64_t num_elements_reference = std::distance(reference_first, reference_last); + + ARM_COMPUTE_EXPECT_EQUAL(num_elements_target, num_elements_reference, framework::LogLevel::ERRORS); + + int64_t num_missing = 0; + int64_t num_mismatches = 0; + + if(num_elements_reference > 0) + { + std::tie(num_missing, num_mismatches) = compare_keypoints(reference_first, reference_last, target_first, target_last, tolerance); + + const float percent_missing = static_cast<float>(num_missing) / num_elements_reference * 100.f; + const float percent_mismatches = static_cast<float>(num_mismatches) / num_elements_reference * 100.f; + + ARM_COMPUTE_TEST_INFO(num_missing << " keypoints (" << std::fixed << std::setprecision(2) << percent_missing << "%) are missing in target"); + ARM_COMPUTE_EXPECT_EQUAL(num_missing, 0, framework::LogLevel::ERRORS); + + ARM_COMPUTE_TEST_INFO(num_mismatches << " keypoints (" << std::fixed << std::setprecision(2) << percent_mismatches << "%) mismatched"); + ARM_COMPUTE_EXPECT_EQUAL(num_mismatches, 0, framework::LogLevel::ERRORS); + } + + if(num_elements_target > 0) + { + std::tie(num_missing, num_mismatches) = compare_keypoints(target_first, target_last, reference_first, reference_last, tolerance); + + const float percent_missing = static_cast<float>(num_missing) / num_elements_target * 100.f; + + ARM_COMPUTE_TEST_INFO(num_missing << " keypoints (" << std::fixed << std::setprecision(2) << percent_missing << "%) are not part of target"); + ARM_COMPUTE_EXPECT_EQUAL(num_missing, 0, framework::LogLevel::ERRORS); + } +} + template <typename T, typename U> -void validate(T target, T reference, U tolerance) +bool validate(T target, T reference, U tolerance) { ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference)); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target)); ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance))); - ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS); + + const bool equal = compare<U>(target, reference, tolerance); + + ARM_COMPUTE_EXPECT(equal, framework::LogLevel::ERRORS); + + return equal; } template <typename T, typename U> |