diff options
Diffstat (limited to 'ethosu/vela/tensor.py')
-rw-r--r-- | ethosu/vela/tensor.py | 105 |
1 files changed, 59 insertions, 46 deletions
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index fb877ca8..ef8a28fc 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -314,26 +314,6 @@ def create_const_tensor( return const_tensor -def create_reshape_tensor(tens, shape, ifm_reshape=True): - if shape == tens.shape: - return tens - # Tensors - name = tens.name + "_reshape" - reshape_ifm = tens - reshape_ofm = tens.clone("_reshaped") - reshape_ofm.set_all_shapes(shape) - if not ifm_reshape: - reshape_ifm, reshape_ofm = reshape_ofm, reshape_ifm - # Operator - reshape_op = Operation(Op.Reshape, name) - reshape_op.attrs["new_shape"] = shape - reshape_op.add_input_tensor(reshape_ifm) - reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape)) - reshape_op.set_output_tensor(reshape_ofm) - reshape_op.set_ifm_ofm_shapes() - return reshape_ofm if ifm_reshape else reshape_ifm - - # class that keeps track of all tensor addresses in the different memory types class TensorAddressMap: address_map: Dict = defaultdict(dict) # dict (tens.equivalence_id -> dict (mem_type -> address)) @@ -443,6 +423,10 @@ class Tensor: def address(self, address: int): TensorAddressMap.set_address_for_tens(self.equivalence_id, self.mem_type, address) + @property + def is_standard_fm(self) -> bool: + return self.sub_purpose == TensorSubPurpose.Standard and self.purpose == TensorPurpose.FeatureMap + def element_size(self) -> int: if self.element_size_bytes == 0: return self.dtype.size_in_bits() / 8 @@ -540,6 +524,15 @@ class Tensor: rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment) return rounded_size + def storage_size_for_shape(self, op_storage_shape: Shape) -> int: + elems = shape_num_elements(op_storage_shape) + elems = elems if elems else 0 + raw_size = elems * self.element_size() + if raw_size == 0: + raw_size = 1 # force it to take up space + rounded_size = numeric_util.round_up(numeric_util.round_up_to_int(raw_size), self.alignment) + return rounded_size + def storage_size_for_sub_purpose( self, arch, sub_purpose: TensorSubPurpose, param_a: Optional[int] = None, param_b: Optional[int] = None ) -> int: @@ -614,7 +607,11 @@ class Tensor: def consumers(self) -> List[Operation]: return self.consumer_list - def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple: + def get_4D_storage_shape_for_shape(self, op_shape4D: Shape4D) -> Shape4D: + rounding_quantum = full_shape(4, list(self.storage_rounding_quantum), 1) + return Shape4D(shape_round_to_quantum(op_shape4D.as_list(), rounding_quantum)) + + def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, op_shape4D: Shape4D) -> Tuple: # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] ) if self.storage_shape == []: @@ -622,12 +619,16 @@ class Tensor: 1, 1, 1, - [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None], + [self.address_for_coordinate(start_coord, op_shape4D=op_shape4D), None, None, None], ) - storage_shape_4D = full_shape(4, self.storage_shape, 1) - crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D[1]) - crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D[2]) + if self.is_standard_fm: + storage_shape_4D = self.get_4D_storage_shape_for_shape(op_shape4D) + else: + storage_shape_4D = Shape4D(self.storage_shape) + + crossing_y = numeric_util.round_up(start_coord[1] + 1, storage_shape_4D.height) + crossing_x = numeric_util.round_up(start_coord[2] + 1, storage_shape_4D.width) crossing_y = min(crossing_y, end_coord[1]) crossing_x = min(crossing_x, end_coord[2]) @@ -636,39 +637,41 @@ class Tensor: box_width = crossing_x - start_coord[2] addresses: List = [None] * 4 - addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list()) + addresses[0] = self.address_for_coordinate(start_coord, op_shape4D=op_shape4D) if end_coord[2] > crossing_x: addresses[1] = self.address_for_coordinate( - [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list() + [start_coord[0], start_coord[1], crossing_x, start_coord[3]], op_shape4D=op_shape4D ) raise UnsupportedFeatureError("Striping in vertical direction is not supported") if end_coord[1] > crossing_y: addresses[2] = self.address_for_coordinate( - [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list() + [start_coord[0], crossing_y, start_coord[2], start_coord[3]], op_shape4D=op_shape4D ) if end_coord[1] > crossing_y and end_coord[2] > crossing_x: addresses[3] = self.address_for_coordinate( - [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list() + [start_coord[0], crossing_y, crossing_x, start_coord[3]], op_shape4D=op_shape4D ) return box_height0, box_height0, box_width, addresses - def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, shape: Shape = None) -> int: - if shape is None: - shape = self.shape - offset = self.address_offset_for_coordinate(coord, shape, is_top_box) + def address_for_coordinate(self, coord: Shape, is_top_box: bool = False, op_shape4D: Shape4D = None) -> int: + offset = self.address_offset_for_coordinate(coord, op_shape4D=op_shape4D, is_top_box=is_top_box) assert offset is not None return self.address + offset - def get_strides_and_coord(self, coord: Optional[Shape] = None) -> Tuple[Optional[Shape], Optional[Shape]]: + def get_strides_and_coord( + self, coord: Optional[Shape] = None, shape4D: Optional[Shape4D] = None + ) -> Tuple[Optional[Shape], Optional[Shape]]: if coord is None: coord = [0] * len(self.storage_shape) + if shape4D and self.is_standard_fm: + augmented_shape = self.get_4D_storage_shape_for_shape(shape4D).as_list() + else: + augmented_shape = full_shape(4, self.storage_shape, 1) + augmented_coord = coord - augmented_shape = self.storage_shape - while len(augmented_shape) < 4: - augmented_shape = [1] + augmented_shape while len(augmented_coord) < 4: augmented_coord = [0] + augmented_coord @@ -713,8 +716,8 @@ class Tensor: return strides, augmented_coord - def get_strides(self) -> Shape: - strides, _ = self.get_strides_and_coord() + def get_strides(self, shape4D: Optional[Shape4D] = None) -> Shape: + strides, _ = self.get_strides_and_coord(shape4D=shape4D) assert strides is not None return strides @@ -769,13 +772,13 @@ class Tensor: assert 0 <= index < len(self.compressed_values) return index == len(self.compressed_values) - 1 - def address_offset_for_coordinate(self, orig_coord: Shape, shape: Shape, is_top_box: bool = False) -> Optional[int]: + def address_offset_for_coordinate( + self, orig_coord: Shape, op_shape4D: Optional[Shape4D] = None, is_top_box: bool = False + ) -> Optional[int]: address_offset = 0 - coord = orig_coord - - coord = coord[-len(self.storage_shape) :] if self.sub_purpose == TensorSubPurpose.Standard: + shape = op_shape4D.as_list() if op_shape4D else self.shape for idx, c in enumerate(orig_coord): if is_top_box: assert c > 0 and c <= shape[idx] @@ -783,6 +786,7 @@ class Tensor: assert c >= 0 and c < shape[idx] if self.format == TensorFormat.WeightsCompressed: + storage_size = self.storage_size() if len(self.weight_compressed_offsets) == 0: return 0 @@ -814,13 +818,22 @@ class Tensor: assert index < len(self.weight_compressed_offsets) address_offset = self.weight_compressed_offsets[index] else: + coord = orig_coord + if op_shape4D and self.is_standard_fm: + storage_shape = self.get_4D_storage_shape_for_shape(op_shape4D).as_list() + storage_size = self.storage_size_for_shape(storage_shape) + else: + storage_shape = self.storage_shape + coord = coord[-len(storage_shape) :] + storage_size = self.storage_size() + if is_top_box: coord = [c - 1 for c in coord] # handle wraparound for partial buffers. make sure to do this after subtracting top box: - coord = [c % self.storage_shape[idx] for idx, c in enumerate(coord)] + coord = [c % storage_shape[idx] for idx, c in enumerate(coord)] - strides, augmented_coord = self.get_strides_and_coord(coord) + strides, augmented_coord = self.get_strides_and_coord(coord, op_shape4D) if strides is None: return None @@ -830,7 +843,7 @@ class Tensor: address_offset += np.dot(augmented_coord, strides) assert address_offset >= 0 - assert address_offset <= self.storage_size() + assert address_offset <= storage_size return address_offset def is_allocated_in_tensor_arena(self, scratch_tensor_mem_area: MemArea) -> bool: |