diff options
author | Tim Hall <tim.hall@arm.com> | 2021-02-04 22:47:46 +0000 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2021-02-05 11:30:49 +0000 |
commit | 73e843f76dd71e4ab5e07a7616c2c4806ca6ac25 (patch) | |
tree | 73c35c5443e041441ba826cacfc12f21d5b30bac /ethosu/vela/shape4d.py | |
parent | 133ba7e39c9517d43690c55197d71733ad0dc38c (diff) | |
download | ethos-u-vela-73e843f76dd71e4ab5e07a7616c2c4806ca6ac25.tar.gz |
vela: Change Shape4D mutability usage
- Removed requirement for cloning shapes when unique values required
by forcing top-level immutability. This alleviates issues with Shapes
being unintentionally shared and then mutated as if value-types.
- Shape4D fields can no longer be assigned without replication.
Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Ic0dbfa349eb0215eabefb4f4e2cf99f12d83699c
Diffstat (limited to 'ethosu/vela/shape4d.py')
-rw-r--r-- | ethosu/vela/shape4d.py | 102 |
1 files changed, 63 insertions, 39 deletions
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py index 8981e20b..e26389a1 100644 --- a/ethosu/vela/shape4d.py +++ b/ethosu/vela/shape4d.py @@ -15,66 +15,90 @@ # limitations under the License. # Description: # Defines the class Shape4D. +from collections import namedtuple + from .numeric_util import full_shape +from .numeric_util import round_up_divide -class Shape4D: +class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])): """ 4D Shape (in NHWC format) """ - def __init__(self, shape, base=1): - assert shape is not None - assert len(shape) <= 4 - self._shape4D = tuple(full_shape(4, shape, base)) + def __new__(cls, n=1, h=1, w=1, c=1): + assert n is not None + if isinstance(n, list): + assert h == 1 and w == 1 and c == 1 + tmp = full_shape(4, n, 1) + self = super(Shape4D, cls).__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3]) + else: + self = super(Shape4D, cls).__new__(cls, n, h, w, c) + return self - def __str__(self): - return f"<Shape4D {self.as_list()}>" + @classmethod + def from_list(cls, shape, base=1): + tmp = full_shape(4, shape, base) + return cls(tmp[0], tmp[1], tmp[2], tmp[3]) + + @classmethod + def from_hwc(cls, h, w, c): + return cls(1, h, w, c) + + def with_batch(self, new_batch): + return Shape4D(new_batch, self.height, self.width, self.depth) - def __eq__(self, other): - return self._shape4D == other._shape4D + def with_height(self, new_height): + return Shape4D(self.batch, new_height, self.width, self.depth) - def clone(self): - return Shape4D(self.as_list()) + def with_width(self, new_width): + return Shape4D(self.batch, self.height, new_width, self.depth) - @property - def batch(self): - return self._shape4D[0] + def with_hw(self, new_height, new_width): + return Shape4D(self.batch, new_height, new_width, self.depth) - @property - def height(self): - return self._shape4D[1] + def with_depth(self, new_depth): + return Shape4D(self.batch, self.height, self.width, new_depth) - @property - def width(self): - return self._shape4D[2] + def add(self, n, h, w, c): + return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c) - @property - def depth(self): - return self._shape4D[3] + def __add__(self, rhs): + return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth) - @batch.setter - def batch(self, new_batch): - self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3]) + def __sub__(self, rhs): + return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth) + + def __floordiv__(self, rhs): + return Shape4D( + self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth + ) + + def __mod__(self, rhs): + return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth) + + def __str__(self): + return f"<Shape4D {list(self)}>" - @height.setter - def height(self, new_height): - self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3]) + def div_round_up(self, rhs): + return Shape4D( + round_up_divide(self.batch, rhs.batch), + round_up_divide(self.height, rhs.height), + round_up_divide(self.width, rhs.width), + round_up_divide(self.depth, rhs.depth), + ) - @width.setter - def width(self, new_width): - self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3]) + def elements(self): + return self.batch * self.width * self.height * self.depth - @depth.setter - def depth(self, new_depth): - self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth) + def elements_wh(self): + return self.width * self.height - def get_dim(self, dim): - assert -4 <= dim < 4 - return self._shape4D[dim] + def is_empty(self): + return (self.batch + self.width + self.height + self.depth) == 0 def as_list(self): - return list(self._shape4D) + return list(self) def get_hw_as_list(self): return list([self.height, self.width]) |