diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2017-06-19 16:11:53 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-09-17 14:14:20 +0100 |
commit | ce093143ec7b554edefc533c90e45c80946cde51 (patch) | |
tree | 1e4aa13ba3fe10c93ca42e2f5477bd2c4888324e /tests/Utils.h | |
parent | 4c2938ed50a78753bfbdbb2f3cbf43f5fed779f9 (diff) | |
download | ComputeLibrary-ce093143ec7b554edefc533c90e45c80946cde51.tar.gz |
COMPMID-403:Add support for 7x7 pooling on CL.
Change-Id: I3c2c8d7e8e61d7737170cb1568900ce4ac337068
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78181
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Diffstat (limited to 'tests/Utils.h')
-rw-r--r-- | tests/Utils.h | 40 |
1 files changed, 37 insertions, 3 deletions
diff --git a/tests/Utils.h b/tests/Utils.h index 53f749df48..f3622cafaa 100644 --- a/tests/Utils.h +++ b/tests/Utils.h @@ -615,13 +615,47 @@ inline int coord2index(const TensorShape &shape, const Coordinates &coord) return index; } +/** Check if Coordinates dimensionality can match the respective shape one. + * + * @param coords Coordinates + * @param shape Shape to match dimensionality + * + * @return True if Coordinates can match the dimensionality of the shape else false. + */ +inline bool match_shape(Coordinates &coords, const TensorShape &shape) +{ + auto check_nz = [](unsigned int i) + { + return i != 0; + }; + + unsigned int coords_dims = coords.num_dimensions(); + unsigned int shape_dims = shape.num_dimensions(); + + // Increase coordinates scenario + if(coords_dims < shape_dims) + { + coords.set_num_dimensions(shape_dims); + return true; + } + // Decrease coordinates scenario + if(coords_dims > shape_dims && !std::any_of(coords.begin() + shape_dims, coords.end(), check_nz)) + { + coords.set_num_dimensions(shape_dims); + return true; + } + + return (coords_dims == shape_dims); +} + /** Check if a coordinate is within a valid region */ inline bool is_in_valid_region(const ValidRegion &valid_region, const Coordinates &coord) { - ARM_COMPUTE_ERROR_ON_MSG(valid_region.shape.num_dimensions() != coord.num_dimensions(), "Shapes of valid region and coordinates do not agree"); - for(int d = 0; static_cast<size_t>(d) < coord.num_dimensions(); ++d) + Coordinates coords(coord); + ARM_COMPUTE_ERROR_ON_MSG(!match_shape(coords, valid_region.shape), "Shapes of valid region and coordinates do not agree"); + for(int d = 0; static_cast<size_t>(d) < coords.num_dimensions(); ++d) { - if(coord[d] < valid_region.start(d) || coord[d] >= valid_region.end(d)) + if(coords[d] < valid_region.start(d) || coords[d] >= valid_region.end(d)) { return false; } |