diff options
Diffstat (limited to 'tests/Utils.h')
-rw-r--r-- | tests/Utils.h | 40 |
1 files changed, 37 insertions, 3 deletions
diff --git a/tests/Utils.h b/tests/Utils.h index 53f749df48..f3622cafaa 100644 --- a/tests/Utils.h +++ b/tests/Utils.h @@ -615,13 +615,47 @@ inline int coord2index(const TensorShape &shape, const Coordinates &coord) return index; } +/** Check if Coordinates dimensionality can match the respective shape one. + * + * @param coords Coordinates + * @param shape Shape to match dimensionality + * + * @return True if Coordinates can match the dimensionality of the shape else false. + */ +inline bool match_shape(Coordinates &coords, const TensorShape &shape) +{ + auto check_nz = [](unsigned int i) + { + return i != 0; + }; + + unsigned int coords_dims = coords.num_dimensions(); + unsigned int shape_dims = shape.num_dimensions(); + + // Increase coordinates scenario + if(coords_dims < shape_dims) + { + coords.set_num_dimensions(shape_dims); + return true; + } + // Decrease coordinates scenario + if(coords_dims > shape_dims && !std::any_of(coords.begin() + shape_dims, coords.end(), check_nz)) + { + coords.set_num_dimensions(shape_dims); + return true; + } + + return (coords_dims == shape_dims); +} + /** Check if a coordinate is within a valid region */ inline bool is_in_valid_region(const ValidRegion &valid_region, const Coordinates &coord) { - ARM_COMPUTE_ERROR_ON_MSG(valid_region.shape.num_dimensions() != coord.num_dimensions(), "Shapes of valid region and coordinates do not agree"); - for(int d = 0; static_cast<size_t>(d) < coord.num_dimensions(); ++d) + Coordinates coords(coord); + ARM_COMPUTE_ERROR_ON_MSG(!match_shape(coords, valid_region.shape), "Shapes of valid region and coordinates do not agree"); + for(int d = 0; static_cast<size_t>(d) < coords.num_dimensions(); ++d) { - if(coord[d] < valid_region.start(d) || coord[d] >= valid_region.end(d)) + if(coords[d] < valid_region.start(d) || coords[d] >= valid_region.end(d)) { return false; } |