Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r--  ethosu/vela/tensor.py | 105
1 file changed, 59 insertions(+), 46 deletions(-)
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index fb877ca8..ef8a28fc 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -314,26 +314,6 @@ def create_const_tensor(
return const_tensor
-def create_reshape_tensor(tens, shape, ifm_reshape=True):
- if shape == tens.shape:
- return tens
- # Tensors
- name = tens.name + "_reshape"
- reshape_ifm = tens
- reshape_ofm = tens.clone("_reshaped")
- reshape_ofm.set_all_shapes(shape)
- if not ifm_reshape:
- reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm
- # Operator
- reshape_op = Operation(Op.Reshape, name)
- reshape_op.attrs["new_shape"] = shape
- reshape_op.add_input_tensor(reshape_ifm)
- reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
- reshape_op.set_output_tensor(reshape_ofm)
- reshape_op.set_ifm_ofm_shapes()
- return reshape_ofm if ifm_reshape else reshape_ifm
-
-
# class that keeps track of all tensor addresses in the different memory types
class TensorAddressMap:
address_map: Dict = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address))
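
The hunk above drops the create_reshape_tensor helper. For reference, a minimal sketch of the pattern it wrapped (ifm-side case only), with imports assuming the ethosu.vela module layout at this revision; callers that still need an explicit reshape would build it roughly like this:

from ethosu.vela.data_type import DataType
from ethosu.vela.operation import Op, Operation
from ethosu.vela.tensor import create_const_tensor

def build_reshape(tens, new_shape):
    # Sketch of the removed helper: clone the tensor, give the clone the new
    # shape and wire a Reshape op between the original and the clone.
    name = tens.name + "_reshape"
    ofm = tens.clone("_reshaped")
    ofm.set_all_shapes(new_shape)
    op = Operation(Op.Reshape, name)
    op.attrs["new_shape"] = new_shape
    op.add_input_tensor(tens)
    # Reshape also carries its target shape as a constant int32 input tensor
    op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, new_shape))
    op.set_output_tensor(ofm)
    op.set_ifm_ofm_shapes()
    return ofm
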
@@ -443,6 +423,10 @@ class Tensor:
def address(self, address: int):
TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address)
+ @property
+ def is_standard_fm(self) -> bool:
+ return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap
+
def element_size(self) -> int:
if self.element_size_bytes == 0:
return self.dtype.size_in_bits() / 8
@@ -540,6 +524,15 @@ class Tensor:
rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
return rounded_size
+ def storage_size_for_shape(self, op_storage_shape: Shape) -> int:
+ elems = shape_num_elements(op_storage_shape)
+ elems = elems if elems else 0
+ raw_size = elems * self.element_size()
+ if raw_size == 0:
+ raw_size = 1 # force it to take up space
+ rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment)
+ return rounded_size
+
def storage_size_for_sub_purpose(
self, arch, sub_purpose: TensorSubPurpose, param_a: Optional[int] = None, param_b: Optional[int] = None
) -> int:
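
The new storage_size_for_shape mirrors the arithmetic of the existing storage_size, but over a caller-supplied shape: element count times element size, a zero-sized result forced to one byte so the tensor still occupies space, and the total rounded up to the tensor's alignment. A standalone sketch of that calculation (hypothetical free functions and default alignment, not the vela API):

import math
from typing import List, Optional

def round_up(value: int, alignment: int) -> int:
    # Round value up to the next multiple of alignment.
    return ((value + alignment - 1) // alignment) * alignment

def storage_size_for_shape(shape: Optional[List[int]], element_size_bytes: float, alignment: int = 16) -> int:
    elems = math.prod(shape) if shape else 0
    raw_size = elems * element_size_bytes
    if raw_size == 0:
        raw_size = 1  # force even empty tensors to take up space
    return round_up(math.ceil(raw_size), alignment)

# e.g. an int8 feature map of shape [1, 7, 7, 20]: 980 bytes, padded to 992 with 16-byte alignment
print(storage_size_for_shape([1, 7, 7, 20], element_size_bytes=1.0, alignment=16))
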
@@ -614,7 +607,11 @@ class Tensor:
def consumers(self) -> List[Operation]:
return self.consumer_list
- def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple:
+ def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D:
+ rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1)
+ return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum))
+
+ def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, op_shape4D: Shape4D) -> Tuple:
# returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
if self.storage_shape == []:
@@ -622,12 +619,16 @@ class Tensor:
1,
1,
1,
- [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None],
+ [self.address_for_coordinate(start_coord, op_shape4D=op_shape4D), None, None, None],
)
- storage_shape_4D = full_shape(4, self.storage_shape, 1)
- crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D[1])
- crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D[2])
+ if self.is_standard_fm:
+ storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D)
+ else:
+ storage_shape_4D = Shape4D(self.storage_shape)
+
+ crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height)
+ crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width)
crossing_y = min(crossing_y, end_coord[1])
crossing_x = min(crossing_x, end_coord[2])
@@ -636,39 +637,41 @@ class Tensor:
box_width = crossing_x - start_coord[2]
addresses: List = [None] * 4
- addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list())
+ addresses[0] = self.address_for_coordinate(start_coord, op_shape4D=op_shape4D)
if end_coord[2] > crossing_x:
addresses[1] = self.address_for_coordinate(
- [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list()
+ [start_coord[0], start_coord[1], crossing_x, start_coord[3]], op_shape4D=op_shape4D
)
raise UnsupportedFeatureError("Striping in vertical direction is not supported")
if end_coord[1] > crossing_y:
addresses[2] = self.address_for_coordinate(
- [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list()
+ [start_coord[0], crossing_y, start_coord[2], start_coord[3]], op_shape4D=op_shape4D
)
if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
addresses[3] = self.address_for_coordinate(
- [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list()
+ [start_coord[0], crossing_y, crossing_x, start_coord[3]], op_shape4D=op_shape4D
)
return box_height0, box_height0, box_width, addresses
- def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, shape: Shape = None) -> int:
- if shape is None:
- shape = self.shape
- offset = self.address_offset_for_coordinate(coord, shape, is_top_box)
+ def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, op_shape4D: Shape4D = None) -> int:
+ offset = self.address_offset_for_coordinate(coord, op_shape4D=op_shape4D, is_top_box=is_top_box)
assert offset is not None
return self.address + offset
- def get_strides_and_coord(self, coord: Optional[Shape] = None) -> Tuple[Optional[Shape], Optional[Shape]]:
+ def get_strides_and_coord(
+ self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None
+ ) -> Tuple[Optional[Shape], Optional[Shape]]:
if coord is None:
coord = [0] * len(self.storage_shape)
+ if shape4D and self.is_standard_fm:
+ augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list()
+ else:
+ augmented_shape = full_shape(4, self.storage_shape, 1)
+
augmented_coord = coord
- augmented_shape = self.storage_shape
- while len(augmented_shape) < 4:
- augmented_shape = [1] + augmented_shape
while len(augmented_coord) < 4:
augmented_coord = [0] + augmented_coord
@@ -713,8 +716,8 @@ class Tensor:
return strides, augmented_coord
- def get_strides(self) -> Shape:
- strides, _ = self.get_strides_and_coord()
+ def get_strides(self, shape4D: Optional[Shape4D] = None) -> Shape:
+ strides, _ = self.get_strides_and_coord(shape4D=shape4D)
assert strides is not None
return strides
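
get_strides now takes an optional shape4D so strides can be derived from the operator's shape rather than the tensor's recorded storage shape. For a plain, densely packed NHWC layout the stride derivation reduces to the sketch below (an illustrative assumption; the real method also handles the other tensor formats):

from typing import List

def nhwc_strides(storage_shape_4d: List[int], element_size_bytes: int) -> List[int]:
    # Byte strides for a dense NHWC tensor, outermost (batch) axis first.
    n, h, w, c = storage_shape_4d
    stride_c = element_size_bytes
    stride_w = c * stride_c
    stride_h = w * stride_w
    stride_n = h * stride_h
    return [stride_n, stride_h, stride_w, stride_c]

# A [1, 7, 7, 32] int8 feature map -> strides [1568, 224, 32, 1]
print(nhwc_strides([1, 7, 7, 32], 1))
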
@@ -769,13 +772,13 @@ class Tensor:
assert 0 <= index < len(self.compressed_values)
return index == len(self.compressed_values) - 1
- def address_offset_for_coordinate(self, orig_coord: Shape, shape: Shape, is_top_box: bool = False) -> Optional[int]:
+ def address_offset_for_coordinate(
+ self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False
+ ) -> Optional[int]:
address_offset = 0
- coord = orig_coord
-
- coord = coord[-len(self.storage_shape) :]
if self.sub_purpose == TensorSubPurpose.Standard:
+ shape = op_shape4D.as_list() if op_shape4D else self.shape
for idx, c in enumerate(orig_coord):
if is_top_box:
assert c > 0 and c <= shape[idx]
@@ -783,6 +786,7 @@ class Tensor:
assert c >= 0 and c < shape[idx]
if self.format == TensorFormat.WeightsCompressed:
+ storage_size = self.storage_size()
if len(self.weight_compressed_offsets) == 0:
return 0
@@ -814,13 +818,22 @@ class Tensor:
assert index < len(self.weight_compressed_offsets)
address_offset = self.weight_compressed_offsets[index]
else:
+ coord = orig_coord
+ if op_shape4D and self.is_standard_fm:
+ storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list()
+ storage_size = self.storage_size_for_shape(storage_shape)
+ else:
+ storage_shape = self.storage_shape
+ coord = coord[-len(storage_shape) :]
+ storage_size = self.storage_size()
+
if is_top_box:
coord = [c - 1 for c in coord]
# handle wraparound for partial buffers. make sure to do this after subtracting top box:
- coord = [c % self.storage_shape[idx] for idx, c in enumerate(coord)]
+ coord = [c % storage_shape[idx] for idx, c in enumerate(coord)]
- strides, augmented_coord = self.get_strides_and_coord(coord)
+ strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D)
if strides is None:
return None
@@ -830,7 +843,7 @@ class Tensor:
address_offset += np.dot(augmented_coord, strides)
assert address_offset >= 0
- assert address_offset <= self.storage_size()
+ assert address_offset <= storage_size
return address_offset
def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool:
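
Taken together, the reworked address path rounds the operator shape up to the tensor's storage rounding quantum (get_4D_storage_shape_for_shape), wraps the coordinate for partial rolling buffers, dots it with the strides for that rounded shape, and bounds-checks the offset against the size computed for the same shape. A standalone end-to-end sketch under the same simplifying assumptions (dense NHWC, hypothetical helper names, and an assumed rounding quantum of [1, 1, 1, 16]):

from typing import List
import numpy as np

def round_up(value: int, quantum: int) -> int:
    return ((value + quantum - 1) // quantum) * quantum

def storage_shape_for_op_shape(op_shape_4d: List[int], rounding_quantum_4d: List[int]) -> List[int]:
    # Mirrors get_4D_storage_shape_for_shape: round each axis up to its quantum.
    return [round_up(d, q) for d, q in zip(op_shape_4d, rounding_quantum_4d)]

def address_offset_for_coordinate(coord: List[int], op_shape_4d: List[int],
                                  rounding_quantum_4d: List[int], element_size: int) -> int:
    storage_shape = storage_shape_for_op_shape(op_shape_4d, rounding_quantum_4d)
    # Wrap the coordinate for partial (rolling) buffers, as the patched method does.
    coord = [c % storage_shape[i] for i, c in enumerate(coord)]
    n, h, w, c = storage_shape
    strides = [h * w * c * element_size, w * c * element_size, c * element_size, element_size]
    offset = int(np.dot(coord, strides))
    assert 0 <= offset <= n * h * w * c * element_size  # same per-shape bound the patch now checks
    return offset

# Coordinate [0, 3, 2, 0] in op shape [1, 7, 7, 20], rounded to storage shape [1, 7, 7, 32]:
print(address_offset_for_coordinate([0, 3, 2, 0], [1, 7, 7, 20], [1, 1, 1, 16], 1))  # 736
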