diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 17 |
1 files changed, 8 insertions, 9 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index 093e8771..df8f8868 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -40,7 +40,6 @@ from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .numeric_util import full_shape from .operation import Op from .operation import Operation -from .shape4d import Shape4D Shape = List @@ -305,7 +304,6 @@ def create_const_tensor( # Operator const_op = Operation(Op.Const, name) const_op.set_output_tensor(const_tensor) - const_op.set_ifm_ofm_shapes() return const_tensor @@ -325,7 +323,8 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True): reshape_op.add_input_tensor(reshape_ifm) reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) reshape_op.set_output_tensor(reshape_ofm) - reshape_op.set_ifm_ofm_shapes() + reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1)) + reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1)) return reshape_ofm if ifm_reshape else reshape_ifm @@ -609,7 +608,7 @@ class Tensor: def consumers(self) -> List[Operation]: return self.consumer_list - def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple: + def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple: # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] ) if self.storage_shape == []: @@ -617,7 +616,7 @@ class Tensor: 1, 1, 1, - [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None], + [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None], ) storage_shape_4D = full_shape(4, self.storage_shape, 1) @@ -631,20 +630,20 @@ class Tensor: box_width = crossing_x - start_coord[2] addresses: List = [None] * 4 - addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list()) + addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape) if end_coord[2] > crossing_x: addresses[1] = self.address_for_coordinate( - [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list() + [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape ) raise UnsupportedFeatureError("Striping in vertical direction is not supported") if end_coord[1] > crossing_y: addresses[2] = self.address_for_coordinate( - [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list() + [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape ) if end_coord[1] > crossing_y and end_coord[2] > crossing_x: addresses[3] = self.address_for_coordinate( - [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list() + [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape ) return box_height0, box_height0, box_width, addresses |