aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/high_level_command_stream.py
diff options
context:
space:
mode:
authorerik.andersson@arm.com <erik.andersson@arm.com>2022-03-22 15:35:30 +0100
committererik.andersson@arm.com <erik.andersson@arm.com>2022-07-11 11:27:47 +0200
commit6b2a0b4a64d01c8b038050a87c29f38a4909515c (patch)
tree0bd213a78debbfbe8465fcbf1c87eadd1f44fc2f /ethosu/vela/high_level_command_stream.py
parent25f48dd70aebeecd490de71eed3d4f7fbad1b121 (diff)
downloadethos-u-vela-6b2a0b4a64d01c8b038050a87c29f38a4909515c.tar.gz
MLBEDSW-6261: Elementwise cascading
Enabled elementwise cascading for binary/single variable IFM operators. Signed-off-by: erik.andersson@arm.com <erik.andersson@arm.com> Change-Id: I1c0867875fdc5c4980224fb570185c11e719d5cd
Diffstat (limited to 'ethosu/vela/high_level_command_stream.py')
-rw-r--r--ethosu/vela/high_level_command_stream.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 0009f6cf..4a41edd0 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -34,6 +34,18 @@ class Box:
for i in range(len(self.start_coord)):
assert self.start_coord[i] <= self.end_coord[i]
+ @staticmethod
+ def wrap(a, b):
+ """Wrap broadcasted tensor boxes in order to
+ prevent out of bounds during box creation"""
+ tmp = [0, 0, 0, 0]
+ for i, val in enumerate(a):
+ if int(val) != 0:
+ tmp[i] = a[i]
+ if a[i] >= b[i] and b[i] != 0:
+ tmp[i] = a[i] % b[i]
+ return Shape4D(tmp)
+
def transform_with_strides_and_skirt(
self,
strides: List[int],
@@ -45,6 +57,7 @@ class Box:
split_offset: Optional[Shape4D] = None,
split_shape: Optional[Shape4D] = None,
upscaling_factor: int = 1,
+ op_type=None,
):
new_start_coord = list(self.start_coord)
new_end_coord = list(self.end_coord)
@@ -115,6 +128,15 @@ class Box:
new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor)
new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape.height), 1)
+ # Wrap the IFMs of broadcasted binary elementwise ops
+ # at the limits of the non-broadcasted volumes
+ # Non-broadcasted ops aren't affected by the wrapping
+ if op_type is not None and op_type.is_binary_elementwise_op():
+ tmp = list(ifm_shape)
+ one = Shape4D(1, 1, 1, 1)
+ new_start_coord = Box.wrap(new_start_coord, tmp)
+ new_end_coord = Box.wrap(Shape4D(list(new_end_coord)) - one, tmp) + one
+
return Box(new_start_coord, new_end_coord), pad_top, pad_bottom
def make_weight_box(weight_shape, npu_block_type, oc_range_start=None, oc_range_end=None, weights_transposed=False):