aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/Helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/Helpers.h')
-rw-r--r--arm_compute/core/Helpers.h22
1 files changed, 20 insertions, 2 deletions
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h
index c6a7db4f96..63fad1dcea 100644
--- a/arm_compute/core/Helpers.h
+++ b/arm_compute/core/Helpers.h
@@ -508,10 +508,28 @@ inline Strides compute_strides(const ITensorInfo &info)
template <typename T>
inline void permute(Dimensions<T> &dimensions, const PermutationVector &perm)
{
- auto copy_dimensions = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
+ auto dimensions_copy = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
{
- dimensions[i] = copy_dimensions[perm[i]];
+ T dimension_val = (perm[i] < dimensions.num_dimensions()) ? dimensions_copy[perm[i]] : 0;
+ dimensions.set(i, dimension_val);
+ }
+}
+
+/** Permutes given TensorShape according to a permutation vector
+ *
+ * @warning Validity of permutation is not checked
+ *
+ * @param[in, out] shape Shape to permute
+ * @param[in] perm Permutation vector
+ */
+inline void permute(TensorShape &shape, const PermutationVector &perm)
+{
+ auto shape_copy = utility::make_array<TensorShape::num_max_dimensions>(shape.begin(), shape.end());
+ for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
+ {
+ size_t dimension_val = (perm[i] < shape.num_dimensions()) ? shape_copy[perm[i]] : 1;
+ shape.set(i, dimension_val);
}
}