From eeb85154b00a9864d0d63e382e9c80ca8e294d5d Mon Sep 17 00:00:00 2001 From: "patrik.gustavsson" Date: Mon, 21 Dec 2020 17:10:40 +0000 Subject: Revert "Revert "MLBEDSW-3645 4D class for op ifm/ofm shapes"" This reverts commit df0a5905177f3a1b836076bc3f9f39b2e86f1794. Reason for revert: Change-Id: I891c66fb29db9d25e942947e8d1c29a10610de51 --- ethosu/vela/tensor.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) (limited to 'ethosu/vela/tensor.py') diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index df8f8868..093e8771 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -40,6 +40,7 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .numeric_util import full_shape from .operation import Op from .operation import Operation +from .shape4d import Shape4D Shape = List @@ -304,6 +305,7 @@ def create_const_tensor( # Operator const_op = Operation(Op.Const, name) const_op.set_output_tensor(const_tensor) + const_op.set_ifm_ofm_shapes() return const_tensor @@ -323,8 +325,7 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True): reshape_op.add_input_tensor(reshape_ifm) reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) reshape_op.set_output_tensor(reshape_ofm) - reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1)) - reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1)) + reshape_op.set_ifm_ofm_shapes() return reshape_ofm if ifm_reshape else reshape_ifm @@ -608,7 +609,7 @@ class Tensor: def consumers(self) -> List[Operation]: return self.consumer_list - def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple: + def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple: # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] ) if self.storage_shape == []: @@ -616,7 +617,7 @@ class Tensor: 1, 1, 1, - [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None], + [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None], ) storage_shape_4D = full_shape(4, self.storage_shape, 1) @@ -630,20 +631,20 @@ class Tensor: box_width = crossing_x - start_coord[2] addresses: List = [None] * 4 - addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape) + addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list()) if end_coord[2] > crossing_x: addresses[1] = self.address_for_coordinate( - [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape + [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list() ) raise UnsupportedFeatureError("Striping in vertical direction is not supported") if end_coord[1] > crossing_y: addresses[2] = self.address_for_coordinate( - [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape + [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list() ) if end_coord[1] > crossing_y and end_coord[2] > crossing_x: addresses[3] = self.address_for_coordinate( - [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape + [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list() ) return box_height0, box_height0, box_width, addresses -- cgit v1.2.1