aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/shape4d.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/shape4d.py')
-rw-r--r--ethosu/vela/shape4d.py102
1 files changed, 63 insertions, 39 deletions
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 8981e20b..e26389a1 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -15,66 +15,90 @@
# limitations under the License.
# Description:
# Defines the class Shape4D.
+from collections import namedtuple
+
from .numeric_util import full_shape
+from .numeric_util import round_up_divide
-class Shape4D:
+class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
"""
4D Shape (in NHWC format)
"""
- def __init__(self, shape, base=1):
- assert shape is not None
- assert len(shape) <= 4
- self._shape4D = tuple(full_shape(4, shape, base))
+ def __new__(cls, n=1, h=1, w=1, c=1):
+ assert n is not None
+ if isinstance(n, list):
+ assert h == 1 and w == 1 and c == 1
+ tmp = full_shape(4, n, 1)
+ self = super(Shape4D, cls).__new__(cls, tmp[0], tmp[1], tmp[2], tmp[3])
+ else:
+ self = super(Shape4D, cls).__new__(cls, n, h, w, c)
+ return self
- def __str__(self):
- return f"<Shape4D {self.as_list()}>"
+ @classmethod
+ def from_list(cls, shape, base=1):
+ tmp = full_shape(4, shape, base)
+ return cls(tmp[0], tmp[1], tmp[2], tmp[3])
+
+ @classmethod
+ def from_hwc(cls, h, w, c):
+ return cls(1, h, w, c)
+
+ def with_batch(self, new_batch):
+ return Shape4D(new_batch, self.height, self.width, self.depth)
- def __eq__(self, other):
- return self._shape4D == other._shape4D
+ def with_height(self, new_height):
+ return Shape4D(self.batch, new_height, self.width, self.depth)
- def clone(self):
- return Shape4D(self.as_list())
+ def with_width(self, new_width):
+ return Shape4D(self.batch, self.height, new_width, self.depth)
- @property
- def batch(self):
- return self._shape4D[0]
+ def with_hw(self, new_height, new_width):
+ return Shape4D(self.batch, new_height, new_width, self.depth)
- @property
- def height(self):
- return self._shape4D[1]
+ def with_depth(self, new_depth):
+ return Shape4D(self.batch, self.height, self.width, new_depth)
- @property
- def width(self):
- return self._shape4D[2]
+ def add(self, n, h, w, c):
+ return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)
- @property
- def depth(self):
- return self._shape4D[3]
+ def __add__(self, rhs):
+ return Shape4D(self.batch + rhs.batch, self.height + rhs.height, self.width + rhs.width, self.depth + rhs.depth)
- @batch.setter
- def batch(self, new_batch):
- self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3])
+ def __sub__(self, rhs):
+ return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth)
+
+ def __floordiv__(self, rhs):
+ return Shape4D(
+ self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
+ )
+
+ def __mod__(self, rhs):
+ return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)
+
+ def __str__(self):
+ return f"<Shape4D {list(self)}>"
- @height.setter
- def height(self, new_height):
- self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3])
+ def div_round_up(self, rhs):
+ return Shape4D(
+ round_up_divide(self.batch, rhs.batch),
+ round_up_divide(self.height, rhs.height),
+ round_up_divide(self.width, rhs.width),
+ round_up_divide(self.depth, rhs.depth),
+ )
- @width.setter
- def width(self, new_width):
- self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3])
+ def elements(self):
+ return self.batch * self.width * self.height * self.depth
- @depth.setter
- def depth(self, new_depth):
- self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth)
+ def elements_wh(self):
+ return self.width * self.height
- def get_dim(self, dim):
- assert -4 <= dim < 4
- return self._shape4D[dim]
+ def is_empty(self):
+ return (self.batch + self.width + self.height + self.depth) == 0
def as_list(self):
- return list(self._shape4D)
+ return list(self)
def get_hw_as_list(self):
return list([self.height, self.width])