diff options
Diffstat (limited to 'tests/validation/Validation.cpp')
-rw-r--r-- | tests/validation/Validation.cpp | 34 |
1 files changed, 33 insertions, 1 deletions
diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp index eac4105b21..868bbaac5e 100644 --- a/tests/validation/Validation.cpp +++ b/tests/validation/Validation.cpp @@ -193,7 +193,6 @@ void check_single_element(const Coordinates &id, const IAccessor &tensor, const BOOST_TEST_INFO("reference = " << std::setprecision(5) << ref); BOOST_TEST_INFO("target = " << std::setprecision(5) << target); BOOST_TEST_WARN(equal); - ++num_mismatches; } ++num_elements; @@ -264,6 +263,39 @@ void validate(const IAccessor &tensor, const RawTensor &reference, const ValidRe << "%) mismatched (maximum tolerated " << std::setprecision(2) << tolerance_number << "%)"); } +void validate(const IAccessor &tensor, const RawTensor &reference, const RawTensor &valid_mask, float tolerance_value, float tolerance_number, uint64_t wrap_range) +{ + int64_t num_mismatches = 0; + int64_t num_elements = 0; + + BOOST_TEST(tensor.element_size() == reference.element_size()); + BOOST_TEST(tensor.format() == reference.format()); + BOOST_TEST(tensor.data_type() == reference.data_type()); + BOOST_TEST(tensor.num_channels() == reference.num_channels()); + BOOST_TEST(compare_dimensions(tensor.shape(), reference.shape())); + + const int min_elements = std::min(tensor.num_elements(), reference.num_elements()); + const int min_channels = std::min(tensor.num_channels(), reference.num_channels()); + const size_t channel_size = element_size_from_data_type(reference.data_type()); + + // Iterate over all elements within valid region, e.g. U8, S16, RGB888, ... + for(int element_idx = 0; element_idx < min_elements; ++element_idx) + { + const Coordinates id = index2coord(reference.shape(), element_idx); + if(valid_mask[element_idx] == 1) + { + check_single_element(id, tensor, reference, tolerance_value, wrap_range, min_channels, channel_size, num_mismatches, num_elements); + } + } + + const int64_t absolute_tolerance_number = tolerance_number * num_elements; + const float percent_mismatches = static_cast<float>(num_mismatches) / num_elements * 100.f; + + BOOST_TEST(num_mismatches <= absolute_tolerance_number, + num_mismatches << " values (" << std::setprecision(2) << percent_mismatches + << "%) mismatched (maximum tolerated " << std::setprecision(2) << tolerance_number << "%)"); +} + void validate(const IAccessor &tensor, const void *reference_value) { BOOST_TEST_REQUIRE((reference_value != nullptr)); |