aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/shape4d.py
diff options
context:
space:
mode:
authorTim Hall <tim.hall@arm.com>2021-05-27 18:49:40 +0100
committerTim Hall <tim.hall@arm.com>2021-05-27 18:57:39 +0100
commitd8339a75c9b655c0507e34238078fdad068b4023 (patch)
tree36a14726b30760169a83c0356803b480992fade8 /ethosu/vela/shape4d.py
parent64556f32ff7bfca6036a6598034464b13b64a4ef (diff)
downloadethos-u-vela-d8339a75c9b655c0507e34238078fdad068b4023.tar.gz
MLBEDSW-4034: New Scheduler Size or Performance Optimisation
- Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b Signed-off-by: Tim Hall <tim.hall@arm.com> Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
Diffstat (limited to 'ethosu/vela/shape4d.py')
-rw-r--r--ethosu/vela/shape4d.py94
1 files changed, 94 insertions, 0 deletions
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
index 5d849d98..fd674031 100644
--- a/ethosu/vela/shape4d.py
+++ b/ethosu/vela/shape4d.py
@@ -16,8 +16,10 @@
# Description:
# Defines the class Shape4D.
from collections import namedtuple
+from enum import Enum
from .numeric_util import full_shape
+from .numeric_util import round_up
from .numeric_util import round_up_divide
@@ -42,6 +44,27 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
return cls(tmp[0], tmp[1], tmp[2], tmp[3])
@classmethod
+ def min(cls, lhs, rhs):
+ return Shape4D(
+ min(lhs.batch, rhs.batch), min(lhs.height, rhs.height), min(lhs.width, rhs.width), min(lhs.depth, rhs.depth)
+ )
+
+ @classmethod
+ def max(cls, lhs, rhs):
+ return Shape4D(
+ max(lhs.batch, rhs.batch), max(lhs.height, rhs.height), max(lhs.width, rhs.width), max(lhs.depth, rhs.depth)
+ )
+
+ @classmethod
+ def round_up(cls, lhs, rhs):
+ return Shape4D(
+ round_up(lhs.batch, rhs.batch),
+ round_up(lhs.height, rhs.height),
+ round_up(lhs.width, rhs.width),
+ round_up(lhs.depth, rhs.depth),
+ )
+
+ @classmethod
def from_hwc(cls, h, w, c):
return cls(1, h, w, c)
@@ -60,6 +83,25 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
def with_depth(self, new_depth):
return Shape4D(self.batch, self.height, self.width, new_depth)
+ def with_axis(self, axis, new_val):
+ shape_as_list = self.as_list()
+ shape_as_list[axis] = new_val
+ return Shape4D.from_list(shape_as_list)
+
+ @staticmethod
+ def _clip_len(pos, length, size):
+ if pos < 0:
+ length = length + pos
+ pos = 0
+ return min(pos + length, size) - pos
+
+ def clip(self, offset, sub_shape):
+ n = Shape4D._clip_len(offset.batch, sub_shape.batch, self.batch)
+ h = Shape4D._clip_len(offset.height, sub_shape.height, self.height)
+ w = Shape4D._clip_len(offset.width, sub_shape.width, self.width)
+ c = Shape4D._clip_len(offset.depth, sub_shape.depth, self.depth)
+ return Shape4D(n, h, w, c)
+
def add(self, n, h, w, c):
return Shape4D(self.batch + n, self.height + h, self.width + w, self.depth + c)
@@ -74,6 +116,9 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth
)
+ def __truediv__(self, rhs):
+ return Shape4D(self.batch / rhs.batch, self.height / rhs.height, self.width / rhs.width, self.depth / rhs.depth)
+
def __mod__(self, rhs):
return Shape4D(self.batch % rhs.batch, self.height % rhs.height, self.width % rhs.width, self.depth % rhs.depth)
@@ -102,3 +147,52 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])):
def get_hw_as_list(self):
return list([self.height, self.width])
+
+
+class VolumeIterator:
+ """
+ 4D Volume iterator. Use to traverse 4D tensor volumes in smaller shapes.
+ """
+
+ class Direction(Enum):
+ CWHN = 0
+
+ def __init__(
+ self,
+ shape: Shape4D,
+ sub_shape: Shape4D,
+ start: Shape4D = Shape4D(0, 0, 0, 0),
+ delta: Shape4D = None,
+ dir=Direction.CWHN,
+ ):
+ self.b = start.batch
+ self.y = start.height
+ self.x = start.width
+ self.z = start.depth
+ self.shape = shape
+ self.sub_shape = sub_shape
+ self.delta = sub_shape if delta is None else delta
+ assert self.delta.elements() > 0, "Iterator will not move"
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.b >= self.shape.batch:
+ raise StopIteration()
+
+ offset = Shape4D(self.b, self.y, self.x, self.z)
+
+ # CWHN
+ self.z += self.delta.depth
+ if self.z >= self.shape.depth:
+ self.z = 0
+ self.x += self.delta.width
+ if self.x >= self.shape.width:
+ self.x = 0
+ self.y += self.delta.height
+ if self.y >= self.shape.height:
+ self.y = 0
+ self.b += self.delta.batch
+
+ return offset, self.shape.clip(offset, self.sub_shape)