diff options
Diffstat (limited to 'ethosu/vela/high_level_command_stream.py')
-rw-r--r-- | ethosu/vela/high_level_command_stream.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py index 0009f6cf..4a41edd0 100644 --- a/ethosu/vela/high_level_command_stream.py +++ b/ethosu/vela/high_level_command_stream.py @@ -34,6 +34,18 @@ class Box: for i in range(len(self.start_coord)): assert self.start_coord[i] <= self.end_coord[i] + @staticmethod + def wrap(a, b): + """Wrap broadcasted tensor boxes in order to + prevent out of bounds during box creation""" + tmp = [0, 0, 0, 0] + for i, val in enumerate(a): + if int(val) != 0: + tmp[i] = a[i] + if a[i] >= b[i] and b[i] != 0: + tmp[i] = a[i] % b[i] + return Shape4D(tmp) + def transform_with_strides_and_skirt( self, strides: List[int], @@ -45,6 +57,7 @@ class Box: split_offset: Optional[Shape4D] = None, split_shape: Optional[Shape4D] = None, upscaling_factor: int = 1, + op_type=None, ): new_start_coord = list(self.start_coord) new_end_coord = list(self.end_coord) @@ -115,6 +128,15 @@ class Box: new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor) new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape.height), 1) + # Wrap the IFMs of broadcasted binary elementwise ops + # at the limits of the non-broadcasted volumes + # Non-broadcasted ops aren't affected by the wrapping + if op_type is not None and op_type.is_binary_elementwise_op(): + tmp = list(ifm_shape) + one = Shape4D(1, 1, 1, 1) + new_start_coord = Box.wrap(new_start_coord, tmp) + new_end_coord = Box.wrap(Shape4D(list(new_end_coord)) - one, tmp) + one + return Box(new_start_coord, new_end_coord), pad_top, pad_bottom def make_weight_box(weight_shape, npu_block_type, oc_range_start=None, oc_range_end=None, weights_transposed=False): |