From 563494c2f447e201e88e6d7133a41e12971777eb Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 30 Apr 2018 17:29:41 +0100 Subject: COMPMID-1084 Rework the way validation is performed for NHWC data layout Change-Id: I00b95f560548da76718298b642c8166f92421097 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129520 Tested-by: Jenkins Reviewed-by: Michele DiGiorgio Reviewed-by: Anthony Barbier --- tests/validation/Validation.h | 72 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 14 deletions(-) (limited to 'tests/validation/Validation.h') diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h index 508fb027ca..ac3643ea4a 100644 --- a/tests/validation/Validation.h +++ b/tests/validation/Validation.h @@ -137,19 +137,45 @@ inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance } template -bool compare_dimensions(const Dimensions &dimensions1, const Dimensions &dimensions2) +bool compare_dimensions(const Dimensions &dimensions1, const Dimensions &dimensions2, const DataLayout &data_layout = DataLayout::NCHW) { - if(dimensions1.num_dimensions() != dimensions2.num_dimensions()) + ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN); + + if(data_layout == DataLayout::NCHW) { - return false; - } + if(dimensions1.num_dimensions() != dimensions2.num_dimensions()) + { + return false; + } - for(unsigned int i = 0; i < dimensions1.num_dimensions(); ++i) + for(unsigned int i = 0; i < dimensions1.num_dimensions(); ++i) + { + if(dimensions1[i] != dimensions2[i]) + { + return false; + } + } + } + else { - if(dimensions1[i] != dimensions2[i]) + // In case a 2D shape becomes 3D after permutation, the permuted tensor will have one dimension more and the first value will be 1 + if((dimensions1.num_dimensions() != dimensions2.num_dimensions()) && ((dimensions1.num_dimensions() != (dimensions2.num_dimensions() + 1)) || (dimensions1.x() != 1))) + { + return false; + } + + if((dimensions1[0] != dimensions2[2]) || (dimensions1[1] != dimensions2[0]) || (dimensions1[2] != dimensions2[1])) { return false; } + + for(unsigned int i = 3; i < dimensions1.num_dimensions(); ++i) + { + if(dimensions1[i] != dimensions2[i]) + { + return false; + } + } } return true; @@ -342,14 +368,14 @@ template void validate(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value, float tolerance_number, float absolute_tolerance_value) { // Validate with valid region covering the entire shape - validate(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number, absolute_tolerance_value); + validate(tensor, reference, shape_to_valid_region(reference.shape()), tolerance_value, tolerance_number, absolute_tolerance_value); } template ::value>::type> void validate_wrap(const IAccessor &tensor, const SimpleTensor &reference, U tolerance_value, float tolerance_number) { // Validate with valid region covering the entire shape - validate_wrap(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number); + validate_wrap(tensor, reference, shape_to_valid_region(reference.shape()), tolerance_value, tolerance_number); } template @@ -367,7 +393,7 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const V } ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape(), tensor.data_layout()), framework::LogLevel::ERRORS); const int min_elements = std::min(tensor.num_elements(), reference.num_elements()); const int min_channels = std::min(tensor.num_channels(), reference.num_channels()); @@ -377,12 +403,18 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const V { const Coordinates id = index2coord(reference.shape(), element_idx); + Coordinates target_id(id); + if(tensor.data_layout() == DataLayout::NHWC) + { + permute(target_id, PermutationVector(2U, 0U, 1U)); + } + if(is_in_valid_region(valid_region, id)) { // Iterate over all channels within one element for(int c = 0; c < min_channels; ++c) { - const T &target_value = reinterpret_cast(tensor(id))[c]; + const T &target_value = reinterpret_cast(tensor(target_id))[c]; const T &reference_value = reinterpret_cast(reference(id))[c]; if(!compare(target_value, reference_value, tolerance_value)) @@ -436,7 +468,7 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor &reference, co } ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape(), tensor.data_layout()), framework::LogLevel::ERRORS); const int min_elements = std::min(tensor.num_elements(), reference.num_elements()); const int min_channels = std::min(tensor.num_channels(), reference.num_channels()); @@ -446,12 +478,18 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor &reference, co { const Coordinates id = index2coord(reference.shape(), element_idx); + Coordinates target_id(id); + if(tensor.data_layout() == DataLayout::NHWC) + { + permute(target_id, PermutationVector(2U, 0U, 1U)); + } + if(is_in_valid_region(valid_region, id)) { // Iterate over all channels within one element for(int c = 0; c < min_channels; ++c) { - const T &target_value = reinterpret_cast(tensor(id))[c]; + const T &target_value = reinterpret_cast(tensor(target_id))[c]; const T &reference_value = reinterpret_cast(reference(id))[c]; bool equal = compare(target_value, reference_value, tolerance_value); @@ -518,7 +556,7 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const S } ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape(), tensor.data_layout()), framework::LogLevel::ERRORS); const int min_elements = std::min(tensor.num_elements(), reference.num_elements()); const int min_channels = std::min(tensor.num_channels(), reference.num_channels()); @@ -528,12 +566,18 @@ void validate(const IAccessor &tensor, const SimpleTensor &reference, const S { const Coordinates id = index2coord(reference.shape(), element_idx); + Coordinates target_id(id); + if(tensor.data_layout() == DataLayout::NHWC) + { + permute(target_id, PermutationVector(2U, 0U, 1U)); + } + if(valid_mask[element_idx] == 1) { // Iterate over all channels within one element for(int c = 0; c < min_channels; ++c) { - const T &target_value = reinterpret_cast(tensor(id))[c]; + const T &target_value = reinterpret_cast(tensor(target_id))[c]; const T &reference_value = reinterpret_cast(reference(id))[c]; if(!compare(target_value, reference_value, tolerance_value)) -- cgit v1.2.1