diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-12-16 13:08:06 +0100 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2020-12-21 07:34:05 +0100 |
commit | bf31d647dc5df47410ee577b12427ddf076d816b (patch) | |
tree | 85ddd620916565aa8565d072b764ca4918b405a1 /ethosu/vela/tensor.py | |
parent | 2349d429d926e258e9a61d34c7fd97660ab9fb98 (diff) | |
download | ethos-u-vela-bf31d647dc5df47410ee577b12427ddf076d816b.tar.gz |
MLBEDSW-3645 4D class for op ifm/ofm shapes
Add 4D shape class for op Ifm/ofm shapes
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ic0a98da9d2f9d085605e39a9ab5a26bad6e702a3
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index df8f8868..093e8771 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -40,6 +40,7 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .numeric_util import full_shape from .operation import Op from .operation import Operation +from .shape4d import Shape4D Shape = List @@ -304,6 +305,7 @@ def create_const_tensor( # Operator const_op = Operation(Op.Const, name) const_op.set_output_tensor(const_tensor) + const_op.set_ifm_ofm_shapes() return const_tensor @@ -323,8 +325,7 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True): reshape_op.add_input_tensor(reshape_ifm) reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) reshape_op.set_output_tensor(reshape_ofm) - reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1)) - reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1)) + reshape_op.set_ifm_ofm_shapes() return reshape_ofm if ifm_reshape else reshape_ifm @@ -608,7 +609,7 @@ class Tensor: def consumers(self) -> List[Operation]: return self.consumer_list - def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple: + def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple: # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] ) if self.storage_shape == []: @@ -616,7 +617,7 @@ class Tensor: 1, 1, 1, - [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None], + [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None], ) storage_shape_4D = full_shape(4, self.storage_shape, 1) @@ -630,20 +631,20 @@ class Tensor: box_width = crossing_x - start_coord[2] addresses: List = [None] * 4 - addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape) + addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list()) if end_coord[2] > crossing_x: addresses[1] = self.address_for_coordinate( - [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape + [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list() ) raise UnsupportedFeatureError("Striping in vertical direction is not supported") if end_coord[1] > crossing_y: addresses[2] = self.address_for_coordinate( - [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape + [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list() ) if end_coord[1] > crossing_y and end_coord[2] > crossing_x: addresses[3] = self.address_for_coordinate( - [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape + [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list() ) return box_height0, box_height0, box_width, addresses |