aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tensor.py
diff options
context:
space:
mode:
authorPatrik Gustavsson <patrik.gustavsson@arm.com>2020-12-01 16:02:29 +0100
committerPatrik Gustavsson <patrik.gustavsson@arm.com>2020-12-18 16:33:32 +0100
commit2349d429d926e258e9a61d34c7fd97660ab9fb98 (patch)
treeb5151d0f12428e47d64b1fb2ce4f2f8c19304a0d /ethosu/vela/tensor.py
parent528a56df829b65f7a2c61953650b123c461095f7 (diff)
downloadethos-u-vela-2349d429d926e258e9a61d34c7fd97660ab9fb98.tar.gz
MLBEDSW-3654 Add/use op ifm/ofm shapes
Add ifm/ofm shapes to op Changed to rely on these shapes Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com> Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--ethosu/vela/tensor.py55
1 files changed, 33 insertions, 22 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 69618d2c..df8f8868 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -37,6 +37,7 @@ from .data_type import DataType
from .errors import UnsupportedFeatureError
from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .numeric_util import full_shape
from .operation import Op
from .operation import Operation
@@ -322,6 +323,8 @@ def create_reshape_tensor(tens, shape, ifm_reshape=True):
reshape_op.add_input_tensor(reshape_ifm)
reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
reshape_op.set_output_tensor(reshape_ofm)
+ reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1))
+ reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1))
return reshape_ofm if ifm_reshape else reshape_ifm
@@ -605,20 +608,20 @@ class Tensor:
def consumers(self) -> List[Operation]:
return self.consumer_list
- def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape) -> Tuple:
+ def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple:
# returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
- if len(start_coord) < 4:
- box_height0 = 1
- box_width = 1
-
- if len(start_coord) >= 2:
- box_width = end_coord[-2] - start_coord[-2]
-
- return box_height0, box_height0, box_width, [self.address_for_coordinate(start_coord), None, None, None]
+ if self.storage_shape == []:
+ return (
+ 1,
+ 1,
+ 1,
+ [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None],
+ )
- crossing_y = numeric_util.round_up(start_coord[1] + 1, self.storage_shape[1])
- crossing_x = numeric_util.round_up(start_coord[2] + 1, self.storage_shape[2])
+ storage_shape_4D = full_shape(4, self.storage_shape, 1)
+ crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D[1])
+ crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D[2])
crossing_y = min(crossing_y, end_coord[1])
crossing_x = min(crossing_x, end_coord[2])
@@ -627,20 +630,28 @@ class Tensor:
box_width = crossing_x - start_coord[2]
addresses: List = [None] * 4
- addresses[0] = self.address_for_coordinate(start_coord)
+ addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape)
if end_coord[2] > crossing_x:
- addresses[1] = self.address_for_coordinate([start_coord[0], start_coord[1], crossing_x, start_coord[3]])
+ addresses[1] = self.address_for_coordinate(
+ [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape
+ )
raise UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
- addresses[2] = self.address_for_coordinate([start_coord[0], crossing_y, start_coord[2], start_coord[3]])
+ addresses[2] = self.address_for_coordinate(
+ [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape
+ )
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
- addresses[3] = self.address_for_coordinate([start_coord[0], crossing_y, crossing_x, start_coord[3]])
+ addresses[3] = self.address_for_coordinate(
+ [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape
+ )
return box_height0, box_height0, box_width, addresses
- def address_for_coordinate(self, coord: Shape, is_top_box: bool = False) -> int:
- offset = self.address_offset_for_coordinate(coord, is_top_box)
+ def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, shape: Shape = None) -> int:
+ if shape is None:
+ shape = self.shape
+ offset = self.address_offset_for_coordinate(coord, shape, is_top_box)
assert offset is not None
return self.address + offset
@@ -752,18 +763,18 @@ class Tensor:
assert 0 <= index < len(self.compressed_values)
return index == len(self.compressed_values) - 1
- def address_offset_for_coordinate(self, orig_coord: Shape, is_top_box: bool = False) -> Optional[int]:
+ def address_offset_for_coordinate(self, orig_coord: Shape, shape: Shape, is_top_box: bool = False) -> Optional[int]:
address_offset = 0
coord = orig_coord
coord = coord[-len(self.storage_shape) :]
if self.sub_purpose == TensorSubPurpose.Standard:
- for idx, c in enumerate(coord):
+ for idx, c in enumerate(orig_coord):
if is_top_box:
- assert c > 0 and c <= self.shape[idx]
+ assert c > 0 and c <= shape[idx]
else:
- assert c >= 0 and c < self.shape[idx]
+ assert c >= 0 and c < shape[idx]
if self.format == TensorFormat.WeightsCompressed:
if len(self.weight_compressed_offsets) == 0:
@@ -830,7 +841,7 @@ class Tensor:
def get_full_shape(self) -> Shape:
d = len(self.shape)
if d in (1, 3):
- return numeric_util.full_shape(4, self.shape, 1)
+ return full_shape(4, self.shape, 1)
elif d == 2:
return [self.shape[0], 1, 1, self.shape[1]]
else: