diff options
Diffstat (limited to 'arm_compute/core/Helpers.h')
-rw-r--r-- | arm_compute/core/Helpers.h | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h index c6a7db4f96..63fad1dcea 100644 --- a/arm_compute/core/Helpers.h +++ b/arm_compute/core/Helpers.h @@ -508,10 +508,28 @@ inline Strides compute_strides(const ITensorInfo &info) template <typename T> inline void permute(Dimensions<T> &dimensions, const PermutationVector &perm) { - auto copy_dimensions = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end()); + auto dimensions_copy = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end()); for(unsigned int i = 0; i < perm.num_dimensions(); ++i) { - dimensions[i] = copy_dimensions[perm[i]]; + T dimension_val = (perm[i] < dimensions.num_dimensions()) ? dimensions_copy[perm[i]] : 0; + dimensions.set(i, dimension_val); + } +} + +/** Permutes given TensorShape according to a permutation vector + * + * @warning Validity of permutation is not checked + * + * @param[in, out] shape Shape to permute + * @param[in] perm Permutation vector + */ +inline void permute(TensorShape &shape, const PermutationVector &perm) +{ + auto shape_copy = utility::make_array<TensorShape::num_max_dimensions>(shape.begin(), shape.end()); + for(unsigned int i = 0; i < perm.num_dimensions(); ++i) + { + size_t dimension_val = (perm[i] < shape.num_dimensions()) ? shape_copy[perm[i]] : 1; + shape.set(i, dimension_val); } } |