aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/Validation.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-04-30 17:29:41 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:17 +0000
commit563494c2f447e201e88e6d7133a41e12971777eb (patch)
tree716ae5e4978ce378ad14f53591087a7a42f6fe58 /tests/validation/Validation.h
parentb7f5d172ccdb1d884388dd6e0e54f74241afca67 (diff)
downloadComputeLibrary-563494c2f447e201e88e6d7133a41e12971777eb.tar.gz
COMPMID-1084 Rework the way validation is performed for NHWC data layout
Change-Id: I00b95f560548da76718298b642c8166f92421097 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129520 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r--tests/validation/Validation.h72
1 files changed, 58 insertions, 14 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index 508fb027ca..ac3643ea4a 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -137,19 +137,45 @@ inline ::std::ostream &operator<<(::std::ostream &os, const RelativeTolerance<T>
}
template <typename T>
-bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &dimensions2)
+bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &dimensions2, const DataLayout &data_layout = DataLayout::NCHW)
{
- if(dimensions1.num_dimensions() != dimensions2.num_dimensions())
+ ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+
+ if(data_layout == DataLayout::NCHW)
{
- return false;
- }
+ if(dimensions1.num_dimensions() != dimensions2.num_dimensions())
+ {
+ return false;
+ }
- for(unsigned int i = 0; i < dimensions1.num_dimensions(); ++i)
+ for(unsigned int i = 0; i < dimensions1.num_dimensions(); ++i)
+ {
+ if(dimensions1[i] != dimensions2[i])
+ {
+ return false;
+ }
+ }
+ }
+ else
{
- if(dimensions1[i] != dimensions2[i])
+ // In case a 2D shape becomes 3D after permutation, the permuted tensor will have one dimension more and the first value will be 1
+ if((dimensions1.num_dimensions() != dimensions2.num_dimensions()) && ((dimensions1.num_dimensions() != (dimensions2.num_dimensions() + 1)) || (dimensions1.x() != 1)))
+ {
+ return false;
+ }
+
+ if((dimensions1[0] != dimensions2[2]) || (dimensions1[1] != dimensions2[0]) || (dimensions1[2] != dimensions2[1]))
{
return false;
}
+
+ for(unsigned int i = 3; i < dimensions1.num_dimensions(); ++i)
+ {
+ if(dimensions1[i] != dimensions2[i])
+ {
+ return false;
+ }
+ }
}
return true;
@@ -342,14 +368,14 @@ template <typename T, typename U>
void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value, float tolerance_number, float absolute_tolerance_value)
{
// Validate with valid region covering the entire shape
- validate(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number, absolute_tolerance_value);
+ validate(tensor, reference, shape_to_valid_region(reference.shape()), tolerance_value, tolerance_number, absolute_tolerance_value);
}
template <typename T, typename U, typename = typename std::enable_if<std::is_integral<T>::value>::type>
void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, U tolerance_value, float tolerance_number)
{
// Validate with valid region covering the entire shape
- validate_wrap(tensor, reference, shape_to_valid_region(tensor.shape()), tolerance_value, tolerance_number);
+ validate_wrap(tensor, reference, shape_to_valid_region(reference.shape()), tolerance_value, tolerance_number);
}
template <typename T, typename U>
@@ -367,7 +393,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V
}
ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape(), tensor.data_layout()), framework::LogLevel::ERRORS);
const int min_elements = std::min(tensor.num_elements(), reference.num_elements());
const int min_channels = std::min(tensor.num_channels(), reference.num_channels());
@@ -377,12 +403,18 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V
{
const Coordinates id = index2coord(reference.shape(), element_idx);
+ Coordinates target_id(id);
+ if(tensor.data_layout() == DataLayout::NHWC)
+ {
+ permute(target_id, PermutationVector(2U, 0U, 1U));
+ }
+
if(is_in_valid_region(valid_region, id))
{
// Iterate over all channels within one element
for(int c = 0; c < min_channels; ++c)
{
- const T &target_value = reinterpret_cast<const T *>(tensor(id))[c];
+ const T &target_value = reinterpret_cast<const T *>(tensor(target_id))[c];
const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
if(!compare<U>(target_value, reference_value, tolerance_value))
@@ -436,7 +468,7 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, co
}
ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape(), tensor.data_layout()), framework::LogLevel::ERRORS);
const int min_elements = std::min(tensor.num_elements(), reference.num_elements());
const int min_channels = std::min(tensor.num_channels(), reference.num_channels());
@@ -446,12 +478,18 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, co
{
const Coordinates id = index2coord(reference.shape(), element_idx);
+ Coordinates target_id(id);
+ if(tensor.data_layout() == DataLayout::NHWC)
+ {
+ permute(target_id, PermutationVector(2U, 0U, 1U));
+ }
+
if(is_in_valid_region(valid_region, id))
{
// Iterate over all channels within one element
for(int c = 0; c < min_channels; ++c)
{
- const T &target_value = reinterpret_cast<const T *>(tensor(id))[c];
+ const T &target_value = reinterpret_cast<const T *>(tensor(target_id))[c];
const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
bool equal = compare<U>(target_value, reference_value, tolerance_value);
@@ -518,7 +556,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const S
}
ARM_COMPUTE_EXPECT_EQUAL(tensor.num_channels(), reference.num_channels(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape()), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(compare_dimensions(tensor.shape(), reference.shape(), tensor.data_layout()), framework::LogLevel::ERRORS);
const int min_elements = std::min(tensor.num_elements(), reference.num_elements());
const int min_channels = std::min(tensor.num_channels(), reference.num_channels());
@@ -528,12 +566,18 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const S
{
const Coordinates id = index2coord(reference.shape(), element_idx);
+ Coordinates target_id(id);
+ if(tensor.data_layout() == DataLayout::NHWC)
+ {
+ permute(target_id, PermutationVector(2U, 0U, 1U));
+ }
+
if(valid_mask[element_idx] == 1)
{
// Iterate over all channels within one element
for(int c = 0; c < min_channels; ++c)
{
- const T &target_value = reinterpret_cast<const T *>(tensor(id))[c];
+ const T &target_value = reinterpret_cast<const T *>(tensor(target_id))[c];
const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
if(!compare<U>(target_value, reference_value, tolerance_value))