From 83be745adba7a9928c03beda65a6a83f14846475 Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Tue, 29 Aug 2017 13:47:03 +0100 Subject: COMPMID-424 Implemented reference implementation and tests for WarpAffine Change-Id: I4924ab1de17adc3b880a5cc22f2497abbc8e221b Reviewed-on: http://mpd-gerrit.cambridge.arm.com/85820 Tested-by: Kaizen Reviewed-by: Steven Niu --- tests/validation/Validation.h | 78 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) (limited to 'tests/validation/Validation.h') diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index 5e5dab0040..f220224991 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -188,6 +188,19 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, U toler template > void validate(const IAccessor &tensor, const SimpleTensor &reference, const ValidRegion &valid_region, U tolerance_value = U(), float tolerance_number = 0.f); +/** Validate tensors with valid mask. + * + * - Dimensionality has to be the same. + * - All values have to match. + * + * @note: wrap_range allows cases where reference tensor rounds up to the wrapping point, causing it to wrap around to + * zero while the test tensor stays at wrapping point to pass. This may permit true erroneous cases (difference between + * reference tensor and test tensor is multiple of wrap_range), but such errors would be detected by + * other test cases. + */ +template > +void validate(const IAccessor &tensor, const SimpleTensor &reference, const SimpleTensor &valid_mask, U tolerance_value = U(), float tolerance_number = 0.f); + /** Validate tensors against constant value. * * - All values have to match. @@ -434,6 +447,71 @@ void validate_keypoints(T target_first, T target_last, U reference_first, U refe } } +template +void validate(const IAccessor &tensor, const SimpleTensor &reference, const SimpleTensor &valid_mask, U tolerance_value, float tolerance_number) +{ + int64_t num_mismatches = 0; + int64_t num_elements = 0; + + ARM_COMPUTE_EXPECT_EQUAL(tensor.element_size(), reference.element_size(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT_EQUAL(tensor.data_type(), reference.data_type(), framework::LogLevel::ERRORS); + + if(reference.format() != Format::UNKNOWN) + { + ARM_COMPUTE_EXPECT_EQUAL(tensor.format(), reference.format(), framework::LogLevel::ERRORS); + } + + ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS); + + const int min_elements = std::min(tensor.num_elements(), reference.num_elements()); + const int min_channels = std::min(tensor.num_channels(), reference.num_channels()); + + // Iterate over all elements within valid region, e.g. U8, S16, RGB888, ... + for(int element_idx = 0; element_idx < min_elements; ++element_idx) + { + const Coordinates id = index2coord(reference.shape(), element_idx); + + if(valid_mask[element_idx] == 1) + { + // Iterate over all channels within one element + for(int c = 0; c < min_channels; ++c) + { + const T &target_value = reinterpret_cast(tensor(id))[c]; + const T &reference_value = reinterpret_cast(reference(id))[c]; + + if(!compare(target_value, reference_value, tolerance_value)) + { + ARM_COMPUTE_TEST_INFO("id = " << id); + ARM_COMPUTE_TEST_INFO("channel = " << c); + ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target_value)); + ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference_value)); + ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast(tolerance_value))); + framework::ARM_COMPUTE_PRINT_INFO(); + + ++num_mismatches; + } + + ++num_elements; + } + } + else + { + ++num_elements; + } + } + + if(num_elements > 0) + { + const int64_t absolute_tolerance_number = tolerance_number * num_elements; + const float percent_mismatches = static_cast(num_mismatches) / num_elements * 100.f; + + ARM_COMPUTE_TEST_INFO(num_mismatches << " values (" << std::fixed << std::setprecision(2) << percent_mismatches + << "%) mismatched (maximum tolerated " << std::setprecision(2) << tolerance_number << "%)"); + ARM_COMPUTE_EXPECT(num_mismatches <= absolute_tolerance_number, framework::LogLevel::ERRORS); + } +} + template bool validate(T target, T reference, U tolerance) { -- cgit v1.2.1