From d8339a75c9b655c0507e34238078fdad068b4023 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Thu, 27 May 2021 18:49:40 +0100 Subject: MLBEDSW-4034: New Scheduler Size or Performance Optimisation - Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b Signed-off-by: Tim Hall Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4 --- ethosu/vela/shape4d.py | 94 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) (limited to 'ethosu/vela/shape4d.py') diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py index 5d849d98..fd674031 100644 --- a/ethosu/vela/shape4d.py +++ b/ethosu/vela/shape4d.py @@ -16,8 +16,10 @@ # Description: # Defines the class Shape4D. from collections import namedtuple +from enum import Enum from .numeric_util import full_shape +from .numeric_util import round_up from .numeric_util import round_up_divide @@ -41,6 +43,27 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])): tmp = full_shape(4, shape, base) return cls(tmp[0], tmp[1], tmp[2], tmp[3]) + @classmethod + def min(cls, lhs, rhs): + return Shape4D( + min(lhs.batch, rhs.batch), min(lhs.height, rhs.height), min(lhs.width, rhs.width), min(lhs.depth, rhs.depth) + ) + + @classmethod + def max(cls, lhs, rhs): + return Shape4D( + max(lhs.batch, rhs.batch), max(lhs.height, rhs.height), max(lhs.width, rhs.width), max(lhs.depth, rhs.depth) + ) + + @classmethod + def round_up(cls, lhs, rhs): + return Shape4D( + round_up(lhs.batch, rhs.batch), + round_up(lhs.height, rhs.height), + round_up(lhs.width, rhs.width), + round_up(lhs.depth, rhs.depth), + ) + @classmethod def from_hwc(cls, h, w, c): return cls(1, h, w, c) @@ -60,6 +83,25 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])): def with_depth(self, new_depth): return Shape4D(self.batch, self.height, self.width, new_depth) + def with_axis(self, axis, new_val): + shape_as_list = self.as_list() + shape_as_list[axis] = new_val + return Shape4D.from_list(shape_as_list) + + @staticmethod + def _clip_len(pos, length, size): + if pos < 0: + length = length + pos + pos = 0 + return min(pos + length, size) - pos + + def clip(self, offset, sub_shape): + n = Shape4D._clip_len(offset.batch, sub_shape.batch, self.batch) + h = Shape4D._clip_len(offset.height, sub_shape.height, self.height) + w = Shape4D._clip_len(offset.width, sub_shape.width, self.width) + c = Shape4D._clip_len(offset.depth, sub_shape.depth, self.depth) + return Shape4D(n, h, w, c) + def add(self, n, h, w, c): return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c) @@ -74,6 +116,9 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])): self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth ) + def __truediv__(self, rhs): + return Shape4D(self.batch / rhs.batch, self.height / rhs.height, self.width / rhs.width, self.depth / rhs.depth) + def __mod__(self, rhs): return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth) @@ -102,3 +147,52 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])): def get_hw_as_list(self): return list([self.height, self.width]) + + +class VolumeIterator: + """ + 4D Volume iterator. Use to traverse 4D tensor volumes in smaller shapes. + """ + + class Direction(Enum): + CWHN = 0 + + def __init__( + self, + shape: Shape4D, + sub_shape: Shape4D, + start: Shape4D = Shape4D(0, 0, 0, 0), + delta: Shape4D = None, + dir=Direction.CWHN, + ): + self.b = start.batch + self.y = start.height + self.x = start.width + self.z = start.depth + self.shape = shape + self.sub_shape = sub_shape + self.delta = sub_shape if delta is None else delta + assert self.delta.elements() > 0, "Iterator will not move" + + def __iter__(self): + return self + + def __next__(self): + if self.b >= self.shape.batch: + raise StopIteration() + + offset = Shape4D(self.b, self.y, self.x, self.z) + + # CWHN + self.z += self.delta.depth + if self.z >= self.shape.depth: + self.z = 0 + self.x += self.delta.width + if self.x >= self.shape.width: + self.x = 0 + self.y += self.delta.height + if self.y >= self.shape.height: + self.y = 0 + self.b += self.delta.batch + + return offset, self.shape.clip(offset, self.sub_shape) -- cgit v1.2.1