From 3f22ec2025c8e1afe6780785fd8c62c015824a63 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Tue, 21 Sep 2021 14:18:44 +0200 Subject: TOSA: Decompose elem op tensors Added decomposition of tensors exceeding maximum size supported by NPU. Signed-off-by: Patrik Gustavsson Change-Id: I17a99cb72947d2f1064a631ad6975ce895c258d5 --- ethosu/vela/shape4d.py | 3 + ethosu/vela/tosa_graph_optimiser.py | 234 ++++++++++++++++++++------------ ethosu/vela/tosa_supported_operators.py | 21 +-- 3 files changed, 159 insertions(+), 99 deletions(-) diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py index 08b2a6a0..fd1ee949 100644 --- a/ethosu/vela/shape4d.py +++ b/ethosu/vela/shape4d.py @@ -111,6 +111,9 @@ class Shape4D(namedtuple("Shape4D", ["batch", "height", "width", "depth"])): def __sub__(self, rhs): return Shape4D(self.batch - rhs.batch, self.height - rhs.height, self.width - rhs.width, self.depth - rhs.depth) + def floordiv_const(self, const): + return Shape4D(self.batch // const, self.height // const, self.width // const, self.depth // const) + def __floordiv__(self, rhs): return Shape4D( self.batch // rhs.batch, self.height // rhs.height, self.width // rhs.width, self.depth // rhs.depth diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index f4aa4534..1e059cc4 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -503,21 +503,71 @@ def convert_table_to_lut(op, arch, nng): return convert_to_lut(op, table.values, "table") -def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n): - part_op = op.clone() - offset = Shape4D(0, 0, 0, 0) +def decompose_tensors_hwc(op): + max_t_size = 65535 + ofm_shape = op.ofm_shapes[0] + ifm_shape = op.ifm_shapes[0] + ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None + + limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size) + + if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()): + ofm_split = ofm_shape.floordiv_const(max_t_size).add(1, 1, 1, 1) - part_op.read_offsets[0] = offset.with_batch(ifm_offset_n) - part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1) - part_op.write_offset = offset.with_batch(ofm_offset_n) - part_op.write_shape = op.ofm_shapes[0].with_batch(1) + for height in range(ofm_split.height): + for width in range(ofm_split.width): + for depth in range(ofm_split.depth): + ofm_offset = Shape4D(0, height * max_t_size, width * max_t_size, depth * max_t_size) + ofm_part_shape = ofm_shape.clip(ofm_offset, limit_shape) + ofm_cut = (ofm_offset, ofm_part_shape) + + ifm_d = depth * max_t_size if ifm_shape.depth == ofm_shape.depth else 0 + ifm_w = width * max_t_size if ifm_shape.width == ofm_shape.width else 0 + ifm_h = height * max_t_size if ifm_shape.height == ofm_shape.height else 0 + ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d) + ifm_part_shape = ifm_shape.clip(ifm_offset, limit_shape) + ifm_cut = (ifm_offset, ifm_part_shape) + + if ifm2_shape is not None: + ifm2_d = depth * max_t_size if ifm2_shape.depth == ofm_shape.depth else 0 + ifm2_w = width * max_t_size if ifm2_shape.width == ofm_shape.width else 0 + ifm2_h = height * max_t_size if ifm2_shape.height == ofm_shape.height else 0 + ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d) + ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape) + ifm2_cut = (ifm2_offset, ifm2_part_shape) + else: + ifm2_offset = None + ifm2_cut = (None, None) + + create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut) + op.ofm.ops.remove(op) + op.ifm.consumer_list.remove(op) + if op.ifm2 is not None: + op.ifm2.consumer_list.remove(op) + return + + +def create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut): + part_op = op.clone() + ifm_read_offset = op.read_offsets[0] if op.read_offsets[0] is not None else Shape4D(0, 0, 0, 0) + ofm_write_offset = op.write_offset if op.write_offset is not None else Shape4D(0, 0, 0, 0) + ifm_offset, ifm_shape = ifm_cut + ofm_offset, ofm_shape = ofm_cut + + part_op.read_offsets[0] = ifm_read_offset + ifm_offset + part_op.read_shapes[0] = ifm_shape + part_op.write_offset = ofm_write_offset + ofm_offset + part_op.write_shape = ofm_shape part_op.ifm_shapes = op.ifm_shapes.copy() part_op.ofm_shapes = op.ofm_shapes.copy() part_op.ifm.consumer_list.append(part_op) op.ofm.ops.append(part_op) - if ifm2_offset_n: - part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n) - part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1) + + ifm2_offset, ifm2_shape = ifm2_cut + if ifm2_offset: + ifm2_read_offset = op.read_offsets[1] if op.read_offsets[1] is not None else Shape4D(0, 0, 0, 0) + part_op.read_offsets[1] = ifm2_read_offset + ifm2_offset + part_op.read_shapes[1] = ifm2_shape part_op.ifm2.consumer_list.append(part_op) @@ -528,114 +578,120 @@ def get_nhwc_stride(shape): return Shape4D(stride_n, stride_y, stride_x, 1) -def decomp_unary_elementwise(op): +def get_elem_shapes_removed_singles(op): """ - Decompose binary elementwise ops with Rank > 3 (H,W,D). - If Rank > 3, all the dimensions above H are viewed as the N dimension. - the elementwise operation will be decomposed to N (of ofm) elementwise operations. - By reading and writing with offsets from/to the ifm/ofm. + Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm """ - ifm = op.ifm - ofm = op.ofm - assert op.type.is_unary_elementwise_op() - assert None not in (ifm, ofm) - assert ifm.shape == ofm.shape - - rank = len(ofm.shape) - if rank > 3: - n = rank - 3 - ofm_decomp_shape = Shape4D(ofm.shape[0:n]) - new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:] - op.ifm_shapes.append(Shape4D(new_ofm_shape)) - op.ofm_shapes.append(Shape4D(new_ofm_shape)) - - if new_ofm_shape[0] == 1: - return - - for height in range(ofm_decomp_shape.height): - for width in range(ofm_decomp_shape.width): - for depth in range(ofm_decomp_shape.depth): - ifm_offset, ofm_offset = Shape4D(0, height, width, depth) - create_elem_part_op(op, ifm_offset, None, ofm_offset) - - ifm.consumer_list.remove(op) - ofm.ops.remove(op) - return - - -def decomp_binary_elementwise(op): + rank = len(op.ofm.shape) + binary = op.ifm2 is not None + new_ofm_shape = [] + new_ifm_shape = [] + new_ifm2_shape = [] + for idx in range(rank): + if op.ofm.shape[idx] != 1: + new_ofm_shape.append(op.ofm.shape[idx]) + new_ifm_shape.append(op.ifm.shape[idx]) + if binary: + new_ifm2_shape.append(op.ifm2.shape[idx]) + if new_ofm_shape == []: + new_ofm_shape = [1] + new_ifm_shape = [1] + new_ifm2_shape = [1] if binary else None + + return new_ofm_shape, new_ifm_shape, new_ifm2_shape + + +def decomp_dims_elementwise(op): """ - Decompose binary elementwise ops with Rank > 3 (H,W,D). + Decompose elementwise ops with Rank > 3 (H,W,D). If Rank > 3, all the dimensions above H are viewed as the N dimension. the elementwise operation will be decomposed to N (of ofm) elementwise operations. By reading and writing with offsets from/to the ifm(s)/ofm. - Note: Broadcast need to be handled, and TOSA allowes for broadcast by both ifm and ifm2 + Note: Broadcast need to be handled for binary elementwise ops, and TOSA allowes for broadcast by both ifm and ifm2 """ ifm = op.ifm ifm2 = op.ifm2 ofm = op.ofm - assert op.type.is_binary_elementwise_op() - assert None not in (ifm, ifm2, ofm) + binary = op.ifm2 is not None + assert len(ofm.shape) <= 6 + + # Remove dimensions that are all 1 + new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op) + rank = len(new_ofm_shape) - rank = len(ofm.shape) if rank > 3: n = rank - 3 - ofm_decomp_shape = Shape4D(ofm.shape[0:n]) - ifm_decomp_shape = Shape4D(ifm.shape[0:n]) - ifm2_decomp_shape = Shape4D(ifm2.shape[0:n]) - + ofm_decomp_shape = Shape4D(new_ofm_shape[0:n]) ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape) - ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape) - ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape) - - new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:] - new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:] - new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:] - - op.ofm_shapes.append(Shape4D(new_ofm_shape)) - op.ifm_shapes.append(Shape4D(new_ifm_shape)) - op.ifm_shapes.append(Shape4D(new_ifm2_shape)) - - if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1: - return + ofm_part_shape = Shape4D(new_ofm_shape[n:]) + op.ofm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:])) + + if binary: + ifm_decomp_shape = Shape4D(new_ifm_shape[0:n]) + ifm2_decomp_shape = Shape4D(new_ifm2_shape[0:n]) + ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape) + ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape) + ifm_part_shape = Shape4D(new_ifm_shape[n:]) + ifm2_part_shape = Shape4D(new_ifm2_shape[n:]) + op.ifm_shapes.append(Shape4D([ifm_decomp_shape.elements()] + new_ifm_shape[n:])) + op.ifm_shapes.append(Shape4D([ifm2_decomp_shape.elements()] + new_ifm2_shape[n:])) + else: + op.ifm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:])) + op_list = [] for height in range(ofm_decomp_shape.height): for width in range(ofm_decomp_shape.width): for depth in range(ofm_decomp_shape.depth): ofm_offset = Shape4D(0, height, width, depth) + ofm_offset = Shape4D(ofm_offset.dot_prod(ofm_decomp_stride), 0, 0, 0) + ofm_cut = (ofm_offset, ofm_part_shape) + + if binary: + ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0 + ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0 + ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0 + ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d) + ifm_offset = Shape4D(ifm_offset.dot_prod(ifm_decomp_stride), 0, 0, 0) + ifm_cut = (ifm_offset, ifm_part_shape) + + ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0 + ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0 + ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0 + ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d) + ifm2_offset = Shape4D(ifm2_offset.dot_prod(ifm2_decomp_stride), 0, 0, 0) + ifm2_cut = (ifm2_offset, ifm2_part_shape) + op_list.append(create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)) + else: + op_list.append(create_elem_part_op(op, ofm_cut, None, ofm_cut)) - ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0 - ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0 - ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0 - ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d) - - ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0 - ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0 - ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0 - ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d) - - ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride) - ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride) - ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride) - create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n) - - ifm.consumer_list.remove(op) - ifm2.consumer_list.remove(op) ofm.ops.remove(op) - return + ifm.consumer_list.remove(op) + if binary: + ifm2.consumer_list.remove(op) + else: + op.ofm_shapes.append(Shape4D(new_ofm_shape)) + op.ifm_shapes.append(Shape4D(new_ifm_shape)) + op.ifm_shapes.append(Shape4D(new_ifm2_shape)) + + return [op] def decomp_elementwise(tens, arch, nng): """ Decompose elementwise ops with Rank > 3 (H,W,D). + Decompose size of tensors exceeding NPU max size """ - assert len(tens.ops) == 1 + if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op(): + op = tens.ops[0] + rank = len(op.ofm.shape) + assert rank <= 6 + + decomp_list = [] + decomp_list = decomp_dims_elementwise(op) - if tens.ops[0].type.is_binary_elementwise_op(): - decomp_binary_elementwise(tens.ops[0]) - elif tens.ops[0].type.is_unary_elementwise_op(): - decomp_unary_elementwise(tens.ops[0]) + for part_op in decomp_list: + decompose_tensors_hwc(part_op) return tens diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py index f5eddccc..1012a615 100644 --- a/ethosu/vela/tosa_supported_operators.py +++ b/ethosu/vela/tosa_supported_operators.py @@ -117,18 +117,19 @@ class TosaSupportedOperators: # This is for a HW limitation, that is to be resolved in SW later on @classmethod @docstring_format_args(tens_dim_range) - def constraint_tens_dimension(cls, op): - "Tensor dimensions must be in the range [{}, {}]" - tens_min, tens_max = cls.tens_dim_range + def constraint_tens_dimension(self, op): + "Tensor dimensions must be in the range [{}, {}], if not elementwise" + tens_min, tens_max = self.tens_dim_range valid = True extra = [] - tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] - if not tensors: - tensors = [tens for tens in op.inputs if tens] - for tens in tensors: - if not all(tens_min <= dim <= tens_max for dim in tens.shape): - valid = False - extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") + if op.type not in self.binary_elem_wise_add_mul_sub: + tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens] + if not tensors: + tensors = [tens for tens in op.inputs if tens] + for tens in tensors: + if not all(tens_min <= dim <= tens_max for dim in tens.shape): + valid = False + extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}") return valid, ", ".join(extra) # TODO This is for a HW limitation, that is to be resolved in SW later on -- cgit v1.2.1