aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/Validation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/Validation.cpp')
-rw-r--r--tests/validation/Validation.cpp34
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp
index eac4105b21..868bbaac5e 100644
--- a/tests/validation/Validation.cpp
+++ b/tests/validation/Validation.cpp
@@ -193,7 +193,6 @@ void check_single_element(const Coordinates &id, const IAccessor &tensor, const
BOOST_TEST_INFO("reference = " << std::setprecision(5) << ref);
BOOST_TEST_INFO("target = " << std::setprecision(5) << target);
BOOST_TEST_WARN(equal);
-
++num_mismatches;
}
++num_elements;
@@ -264,6 +263,39 @@ void validate(const IAccessor &tensor, const RawTensor &reference, const ValidRe
<< "%) mismatched (maximum tolerated " << std::setprecision(2) << tolerance_number << "%)");
}
+void validate(const IAccessor &tensor, const RawTensor &reference, const RawTensor &valid_mask, float tolerance_value, float tolerance_number, uint64_t wrap_range)
+{
+ int64_t num_mismatches = 0;
+ int64_t num_elements = 0;
+
+ BOOST_TEST(tensor.element_size() == reference.element_size());
+ BOOST_TEST(tensor.format() == reference.format());
+ BOOST_TEST(tensor.data_type() == reference.data_type());
+ BOOST_TEST(tensor.num_channels() == reference.num_channels());
+ BOOST_TEST(compare_dimensions(tensor.shape(), reference.shape()));
+
+ const int min_elements = std::min(tensor.num_elements(), reference.num_elements());
+ const int min_channels = std::min(tensor.num_channels(), reference.num_channels());
+ const size_t channel_size = element_size_from_data_type(reference.data_type());
+
+ // Iterate over all elements within valid region, e.g. U8, S16, RGB888, ...
+ for(int element_idx = 0; element_idx < min_elements; ++element_idx)
+ {
+ const Coordinates id = index2coord(reference.shape(), element_idx);
+ if(valid_mask[element_idx] == 1)
+ {
+ check_single_element(id, tensor, reference, tolerance_value, wrap_range, min_channels, channel_size, num_mismatches, num_elements);
+ }
+ }
+
+ const int64_t absolute_tolerance_number = tolerance_number * num_elements;
+ const float percent_mismatches = static_cast<float>(num_mismatches) / num_elements * 100.f;
+
+ BOOST_TEST(num_mismatches <= absolute_tolerance_number,
+ num_mismatches << " values (" << std::setprecision(2) << percent_mismatches
+ << "%) mismatched (maximum tolerated " << std::setprecision(2) << tolerance_number << "%)");
+}
+
void validate(const IAccessor &tensor, const void *reference_value)
{
BOOST_TEST_REQUIRE((reference_value != nullptr));