diff options
author | Moritz Pflanzer <moritz.pflanzer@arm.com> | 2017-09-24 12:09:41 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | 6c6597c1e17c32c6ad861780eee454a7deecfb75 (patch) | |
tree | 5df015557262a83e5e84a5fa365544bb1aa66762 /tests/validation/Validation.h | |
parent | c26ecf8ca13205cab2ce43d9f971e1569808e5bc (diff) | |
download | ComputeLibrary-6c6597c1e17c32c6ad861780eee454a7deecfb75.tar.gz |
COMPMID-500: Move HarrisCorners to new validation
Change-Id: I4e21ad98d029e360010c5927f04b716527700a00
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/88888
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r-- | tests/validation/Validation.h | 87 |
1 files changed, 84 insertions, 3 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index b6e7b8e82b..5e5dab0040 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -25,6 +25,7 @@ #define __ARM_COMPUTE_TEST_VALIDATION_H__ #include "arm_compute/core/FixedPoint.h" +#include "arm_compute/core/IArray.h" #include "arm_compute/core/Types.h" #include "tests/IAccessor.h" #include "tests/SimpleTensor.h" @@ -212,7 +213,11 @@ void validate(std::vector<unsigned int> classified_labels, std::vector<unsigned * - All values should match */ template <typename T, typename U = AbsoluteTolerance<T>> -void validate(T target, T reference, U tolerance = AbsoluteTolerance<T>()); +bool validate(T target, T reference, U tolerance = AbsoluteTolerance<T>()); + +/** Validate key points. */ +template <typename T, typename U, typename V = AbsoluteTolerance<float>> +void validate_keypoints(T target_first, T target_last, U reference_first, U reference_last, V tolerance = AbsoluteTolerance<float>()); template <typename T> struct compare_base @@ -358,13 +363,89 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V } } +/** Check which keypoints from [first1, last1) are missing in [first2, last2) */ +template <typename T, typename U, typename V> +std::pair<int64_t, int64_t> compare_keypoints(T first1, T last1, U first2, U last2, V tolerance) +{ + int64_t num_missing = 0; + int64_t num_mismatches = 0; + + while(first1 != last1) + { + const auto point = std::find_if(first2, last2, [&](KeyPoint point) + { + return point.x == first1->x && point.y == first1->y; + }); + + if(point == last2) + { + ++num_missing; + ARM_COMPUTE_TEST_INFO("keypoint1 = " << *first1) + ARM_COMPUTE_EXPECT_FAIL("Key point not found", framework::LogLevel::DEBUG); + } + else if(!validate(point->tracking_status, first1->tracking_status) || !validate(point->strength, first1->strength, tolerance) || !validate(point->scale, first1->scale) + || !validate(point->orientation, first1->orientation) || !validate(point->error, first1->error)) + { + ++num_mismatches; + ARM_COMPUTE_TEST_INFO("keypoint1 = " << *first1) + ARM_COMPUTE_TEST_INFO("keypoint2 = " << *point) + ARM_COMPUTE_EXPECT_FAIL("Mismatching keypoint", framework::LogLevel::DEBUG); + } + + ++first1; + } + + return std::make_pair(num_missing, num_mismatches); +} + +template <typename T, typename U, typename V> +void validate_keypoints(T target_first, T target_last, U reference_first, U reference_last, V tolerance) +{ + const int64_t num_elements_target = std::distance(target_first, target_last); + const int64_t num_elements_reference = std::distance(reference_first, reference_last); + + ARM_COMPUTE_EXPECT_EQUAL(num_elements_target, num_elements_reference, framework::LogLevel::ERRORS); + + int64_t num_missing = 0; + int64_t num_mismatches = 0; + + if(num_elements_reference > 0) + { + std::tie(num_missing, num_mismatches) = compare_keypoints(reference_first, reference_last, target_first, target_last, tolerance); + + const float percent_missing = static_cast<float>(num_missing) / num_elements_reference * 100.f; + const float percent_mismatches = static_cast<float>(num_mismatches) / num_elements_reference * 100.f; + + ARM_COMPUTE_TEST_INFO(num_missing << " keypoints (" << std::fixed << std::setprecision(2) << percent_missing << "%) are missing in target"); + ARM_COMPUTE_EXPECT_EQUAL(num_missing, 0, framework::LogLevel::ERRORS); + + ARM_COMPUTE_TEST_INFO(num_mismatches << " keypoints (" << std::fixed << std::setprecision(2) << percent_mismatches << "%) mismatched"); + ARM_COMPUTE_EXPECT_EQUAL(num_mismatches, 0, framework::LogLevel::ERRORS); + } + + if(num_elements_target > 0) + { + std::tie(num_missing, num_mismatches) = compare_keypoints(target_first, target_last, reference_first, reference_last, tolerance); + + const float percent_missing = static_cast<float>(num_missing) / num_elements_target * 100.f; + + ARM_COMPUTE_TEST_INFO(num_missing << " keypoints (" << std::fixed << std::setprecision(2) << percent_missing << "%) are not part of target"); + ARM_COMPUTE_EXPECT_EQUAL(num_missing, 0, framework::LogLevel::ERRORS); + } +} + template <typename T, typename U> -void validate(T target, T reference, U tolerance) +bool validate(T target, T reference, U tolerance) { ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference)); ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target)); ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance))); - ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS); + + const bool equal = compare<U>(target, reference, tolerance); + + ARM_COMPUTE_EXPECT(equal, framework::LogLevel::ERRORS); + + return equal; } template <typename T, typename U> |