aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/Validation.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/Validation.h')
-rw-r--r--tests/validation/Validation.h8
1 files changed, 6 insertions, 2 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index c2df1c31c0..a75562bac2 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -157,11 +157,15 @@ bool compare_dimensions(const Dimensions<T> &dimensions1, const Dimensions<T> &d
}
else
{
- // In case a 2D shape becomes 3D after permutation, the permuted tensor will have one dimension more and the first value will be 1
- if((dimensions1.num_dimensions() != dimensions2.num_dimensions()) && ((dimensions1.num_dimensions() != (dimensions2.num_dimensions() + 1)) || (dimensions1.x() != 1)))
+ // In case a 1D/2D shape becomes 3D after permutation, the permuted tensor will have two/one dimension(s) more and the first (two) value(s) will be 1
+ // clang-format off
+ if((dimensions1.num_dimensions() != dimensions2.num_dimensions()) &&
+ ((dimensions1.num_dimensions() != (dimensions2.num_dimensions() + 1)) || (dimensions1.x() != 1)) &&
+ ((dimensions1.num_dimensions() != (dimensions2.num_dimensions() + 2)) || (dimensions1.x() != 1) || (dimensions1.y() != 1)))
{
return false;
}
+ // clang-format on
if((dimensions1[0] != dimensions2[2]) || (dimensions1[1] != dimensions2[0]) || (dimensions1[2] != dimensions2[1]))
{