aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/Validation.h
diff options
context:
space:
mode:
authorMoritz Pflanzer <moritz.pflanzer@arm.com>2017-09-24 12:09:41 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit6c6597c1e17c32c6ad861780eee454a7deecfb75 (patch)
tree5df015557262a83e5e84a5fa365544bb1aa66762 /tests/validation/Validation.h
parentc26ecf8ca13205cab2ce43d9f971e1569808e5bc (diff)
downloadComputeLibrary-6c6597c1e17c32c6ad861780eee454a7deecfb75.tar.gz
COMPMID-500: Move HarrisCorners to new validation
Change-Id: I4e21ad98d029e360010c5927f04b716527700a00 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/88888 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r--tests/validation/Validation.h87
1 files changed, 84 insertions, 3 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index b6e7b8e82b..5e5dab0040 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -25,6 +25,7 @@
#define __ARM_COMPUTE_TEST_VALIDATION_H__
#include "arm_compute/core/FixedPoint.h"
+#include "arm_compute/core/IArray.h"
#include "arm_compute/core/Types.h"
#include "tests/IAccessor.h"
#include "tests/SimpleTensor.h"
@@ -212,7 +213,11 @@ void validate(std::vector<unsigned int> classified_labels, std::vector<unsigned
* - All values should match
*/
template <typename T, typename U = AbsoluteTolerance<T>>
-void validate(T target, T reference, U tolerance = AbsoluteTolerance<T>());
+bool validate(T target, T reference, U tolerance = AbsoluteTolerance<T>());
+
+/** Validate key points. */
+template <typename T, typename U, typename V = AbsoluteTolerance<float>>
+void validate_keypoints(T target_first, T target_last, U reference_first, U reference_last, V tolerance = AbsoluteTolerance<float>());
template <typename T>
struct compare_base
@@ -358,13 +363,89 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V
}
}
+/** Check which keypoints from [first1, last1) are missing in [first2, last2) */
+template <typename T, typename U, typename V>
+std::pair<int64_t, int64_t> compare_keypoints(T first1, T last1, U first2, U last2, V tolerance)
+{
+ int64_t num_missing = 0;
+ int64_t num_mismatches = 0;
+
+ while(first1 != last1)
+ {
+ const auto point = std::find_if(first2, last2, [&](KeyPoint point)
+ {
+ return point.x == first1->x && point.y == first1->y;
+ });
+
+ if(point == last2)
+ {
+ ++num_missing;
+ ARM_COMPUTE_TEST_INFO("keypoint1 = " << *first1)
+ ARM_COMPUTE_EXPECT_FAIL("Key point not found", framework::LogLevel::DEBUG);
+ }
+ else if(!validate(point->tracking_status, first1->tracking_status) || !validate(point->strength, first1->strength, tolerance) || !validate(point->scale, first1->scale)
+ || !validate(point->orientation, first1->orientation) || !validate(point->error, first1->error))
+ {
+ ++num_mismatches;
+ ARM_COMPUTE_TEST_INFO("keypoint1 = " << *first1)
+ ARM_COMPUTE_TEST_INFO("keypoint2 = " << *point)
+ ARM_COMPUTE_EXPECT_FAIL("Mismatching keypoint", framework::LogLevel::DEBUG);
+ }
+
+ ++first1;
+ }
+
+ return std::make_pair(num_missing, num_mismatches);
+}
+
+template <typename T, typename U, typename V>
+void validate_keypoints(T target_first, T target_last, U reference_first, U reference_last, V tolerance)
+{
+ const int64_t num_elements_target = std::distance(target_first, target_last);
+ const int64_t num_elements_reference = std::distance(reference_first, reference_last);
+
+ ARM_COMPUTE_EXPECT_EQUAL(num_elements_target, num_elements_reference, framework::LogLevel::ERRORS);
+
+ int64_t num_missing = 0;
+ int64_t num_mismatches = 0;
+
+ if(num_elements_reference > 0)
+ {
+ std::tie(num_missing, num_mismatches) = compare_keypoints(reference_first, reference_last, target_first, target_last, tolerance);
+
+ const float percent_missing = static_cast<float>(num_missing) / num_elements_reference * 100.f;
+ const float percent_mismatches = static_cast<float>(num_mismatches) / num_elements_reference * 100.f;
+
+ ARM_COMPUTE_TEST_INFO(num_missing << " keypoints (" << std::fixed << std::setprecision(2) << percent_missing << "%) are missing in target");
+ ARM_COMPUTE_EXPECT_EQUAL(num_missing, 0, framework::LogLevel::ERRORS);
+
+ ARM_COMPUTE_TEST_INFO(num_mismatches << " keypoints (" << std::fixed << std::setprecision(2) << percent_mismatches << "%) mismatched");
+ ARM_COMPUTE_EXPECT_EQUAL(num_mismatches, 0, framework::LogLevel::ERRORS);
+ }
+
+ if(num_elements_target > 0)
+ {
+ std::tie(num_missing, num_mismatches) = compare_keypoints(target_first, target_last, reference_first, reference_last, tolerance);
+
+ const float percent_missing = static_cast<float>(num_missing) / num_elements_target * 100.f;
+
+ ARM_COMPUTE_TEST_INFO(num_missing << " keypoints (" << std::fixed << std::setprecision(2) << percent_missing << "%) are not part of target");
+ ARM_COMPUTE_EXPECT_EQUAL(num_missing, 0, framework::LogLevel::ERRORS);
+ }
+}
+
template <typename T, typename U>
-void validate(T target, T reference, U tolerance)
+bool validate(T target, T reference, U tolerance)
{
ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference));
ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target));
ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance)));
- ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS);
+
+ const bool equal = compare<U>(target, reference, tolerance);
+
+ ARM_COMPUTE_EXPECT(equal, framework::LogLevel::ERRORS);
+
+ return equal;
}
template <typename T, typename U>