diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 38b0e430..e9815845 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -823,6 +823,19 @@ class Tensor: else: return self.values.item(0) + def get_shape_as_2d(self, dimension_2_size: int) -> Optional[Shape4D]: + + elms = self.elements() + dimension_1_size = elms // dimension_2_size + # Checks if the reduction works and shape is not 1D + is_reducible = dimension_1_size * dimension_2_size == elms and not (len(self.shape) == 1) + + new_shape = None + if is_reducible: + new_shape = Shape4D([dimension_1_size, 1, 1, dimension_2_size]) + + return new_shape + def __lt__(self, other: "Tensor") -> bool: return self.equivalence_id < other.equivalence_id |