From c2b129dbdd9b493fee904bb07838bb0b9e247e96 Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Thu, 23 Sep 2021 13:52:34 +0200
Subject: TOSA: Decomposition of CONCAT

-Added support for unlimited number of dimensions
-Added support for Tensors with dimension size exceeding maximum limit of NPU.

Signed-off-by: Patrik Gustavsson
Change-Id: I3cc7327ac759e69042a600e686160aeb18a5ec59
---
 ethosu/vela/tosa_graph_optimiser.py     | 144 ++++++++++++++++++++++----------
 ethosu/vela/tosa_supported_operators.py |  28 +++++--
 2 files changed, 117 insertions(+), 55 deletions(-)

diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 1e059cc4..5cd9d210 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -265,32 +265,21 @@ def remove_splitsliceread(op, arch):
         DebugDatabase.add_optimised(op, add_op)
 
 
-def rewrite_concat_ops(op, arch):
+def rewrite_concat(op):
     if not op.run_on_npu or not op.type == Op.Concat:
         return
 
-    axis_4D = 0
-    ofm = op.ofm
-    ofm.ops = []
     offset = 0
-
     inputs = op.inputs
-    axis = op.attrs["axis"]
+    axis_4D = op.attrs["axis4D"]
 
     for idx, inp in enumerate(inputs):
-        op.ifm_shapes[idx] = Shape4D(inp.shape)
-        if axis >= 0:
-            axis_4D = axis + (4 - len(inp.shape))
-        else:
-            axis_4D = axis
         write_offset = [0, 0, 0, 0]
         write_offset[axis_4D] = offset
         concat_end = offset + op.ifm_shapes[idx][axis_4D]
         create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
         offset = concat_end
-    assert ofm.shape[axis] == offset
-
-    return op
+    assert op.ofm_shapes[0][axis_4D] == offset
 
 
 def remove_reshapes(op, arch):
@@ -503,12 +492,15 @@ def convert_table_to_lut(op, arch, nng):
     return convert_to_lut(op, table.values, "table")
 
 
-def decompose_tensors_hwc(op):
+def decompose_elem_tensors_hwc(op):
+    """
+    Decomposes elementwise op if any of the ifm(s)/ofm are too large in any dimension to be handled by the NPU
+    """
     max_t_size = 65535
-    ofm_shape = op.ofm_shapes[0]
-    ifm_shape = op.ifm_shapes[0]
+    ofm_shape = op.write_shape if op.write_shape is not None else op.ofm_shapes[0]
+    ifm_shape = op.read_shapes[0] if op.read_shapes[0] is not None else op.ifm_shapes[0]
     ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
-
+    ifm2_shape = op.read_shapes[1] if op.read_shapes[1] is not None else ifm2_shape
     limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)
 
     if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
@@ -536,7 +528,6 @@
             ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
             ifm2_cut = (ifm2_offset, ifm2_part_shape)
         else:
-            ifm2_offset = None
             ifm2_cut = (None, None)
 
         create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
@@ -582,17 +573,23 @@ def get_elem_shapes_removed_singles(op):
     """
     Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm
     """
-    rank = len(op.ofm.shape)
     binary = op.ifm2 is not None
+    ofm_shape = op.ofm_shapes[0].as_list() if len(op.ofm_shapes) > 0 else op.ofm.shape
+    ifm_shape = op.ifm_shapes[0].as_list() if len(op.ifm_shapes) > 0 else op.ifm.shape
+    if binary:
+        ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape
+
+    rank = len(ofm_shape)
     new_ofm_shape = []
     new_ifm_shape = []
     new_ifm2_shape = []
     for idx in range(rank):
-        if op.ofm.shape[idx] != 1:
-            new_ofm_shape.append(op.ofm.shape[idx])
-            new_ifm_shape.append(op.ifm.shape[idx])
+        if ofm_shape[idx] != 1:
+            new_ofm_shape.append(ofm_shape[idx])
+            new_ifm_shape.append(ifm_shape[idx])
             if binary:
-                new_ifm2_shape.append(op.ifm2.shape[idx])
+                new_ifm2_shape.append(ifm2_shape[idx])
+
     if new_ofm_shape == []:
         new_ofm_shape = [1]
         new_ifm_shape = [1]
@@ -614,7 +611,6 @@ def decomp_dims_elementwise(op):
     ifm2 = op.ifm2
     ofm = op.ofm
     binary = op.ifm2 is not None
-    assert len(ofm.shape) <= 6
 
     # Remove dimensions that are all 1
     new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
@@ -679,19 +675,74 @@ def decomp_dims_elementwise(op):
 
 def decomp_elementwise(tens, arch, nng):
     """
-    Decompose elementwise ops with Rank > 3 (H,W,D).
+    Decompose elementwise ops with Rank > 3 (H,W,C).
     Decompose size of tensors exceeding NPU max size
     """
-    if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op():
+    tens_ops = tens.ops.copy()
+    for op in tens_ops:
+        if op.type.is_elementwise_op():
+            decomp_list = decomp_dims_elementwise(op)
+            for part_op in decomp_list:
+                decompose_elem_tensors_hwc(part_op)
+    return tens
+
+
+def reshape_concat_shape(shape, rank, axis):
+    new_h = 1
+    for i in range(axis):
+        new_h *= shape[i]
+    new_c = 1
+    for i in range(axis + 1, rank):
+        new_c *= shape[i]
+    if axis == (rank - 1):
+        new_shape = [new_h, shape[axis], 1]
+    else:
+        new_shape = [new_h, shape[axis], new_c]
+    return new_shape
+
+
+def reshape_concat(op):
+    """
+    Reshapes concat ops with Rank > 3 (H,W,C).
+    """
+    ofm = op.ofm
+    rank = len(ofm.shape)
+    axis = op.attrs["axis"]
+    if axis < 0:
+        axis += rank
+
+    if rank > 3:
+        # Reshape so that the axis to be concatenated is the W dimension
+        # Reshape inputs
+        for inp in op.inputs:
+            new_shape = reshape_concat_shape(inp.shape, rank, axis)
+            op.ifm_shapes.append(Shape4D(new_shape))
+        # Reshape output
+        new_shape = reshape_concat_shape(ofm.shape, rank, axis)
+        op.ofm_shapes.append(Shape4D(new_shape))
+        op.attrs["axis4D"] = 2
+    else:
+        for inp in op.inputs:
+            op.ifm_shapes.append(Shape4D(inp.shape))
+        op.ofm_shapes.append(Shape4D(ofm.shape))
+        op.attrs["axis4D"] = axis + (4 - rank)
+
+
+def decomp_rewrite_concat(tens, arch, nng):
+    """
+    Decompose concat ops with Rank > 3 (H,W,C).
+    Rewrite of concat to elementwise operations
+    """
+    if len(tens.ops) == 1 and tens.ops[0].type == Op.Concat:
         op = tens.ops[0]
-        rank = len(op.ofm.shape)
-        assert rank <= 6
-        decomp_list = []
-        decomp_list = decomp_dims_elementwise(op)
+        reshape_concat(op)
+        rewrite_concat(op)
+
+        op.ofm.ops.remove(op)
+        for inp in op.inputs:
+            inp.consumer_list.remove(op)
-        for part_op in decomp_list:
-            decompose_tensors_hwc(part_op)
 
     return tens
 
 
@@ -714,21 +765,27 @@ def supported_operator_check(op, arch, nng):
 
 
 def tosa_optimise_graph(nng, arch):
-    # Decomposing to 4 dimensions
+    # TODO the supported operator checking needs to be split into semantic and HW checks
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+            nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False,
         )
 
-    # Pre-processing step
-    pre_process_list = [
-        supported_operator_check,
-        set_ifm_ofm_op_shapes,
-    ]
+    # Decomposing and rewrite of concat
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
+        )
 
+    # Decomposing of elementwise
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+            nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+        )
+
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False,
         )
 
     # Removal of Transpose
@@ -743,11 +800,6 @@ def tosa_optimise_graph(nng, arch):
             nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
         )
 
-    # Rewrite concat ops
-    for idx, sg in enumerate(nng.subgraphs):
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
-        sg.refresh_after_modification()
-
     # Removal of reshapes
     for sg in nng.subgraphs:
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 1012a615..d71e5750 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -46,6 +46,10 @@ class TosaSupportedOperators:
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))
 
+    rank_unlimited_ops = set((Op.Concat,))
+    rank6_limited_ops = elem_wise_ops
+    batch_enabled_ops = elem_wise_ops | set((Op.Concat,))
+    large_tens_dims_enabled_ops = elem_wise_ops | set((Op.Concat,))
     npu_post_ops = activation_ops
 
     supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
@@ -60,8 +64,10 @@ class TosaSupportedOperators:
         self.generic_constraints = []
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_batch)  # TODO as not supported yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported for all ops yet
+        self.generic_constraints.append(
+            TosaSupportedOperators.constraint_batch
+        )  # TODO as not supported for all ops yet
 
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
@@ -118,11 +124,11 @@ class TosaSupportedOperators:
     @classmethod
     @docstring_format_args(tens_dim_range)
     def constraint_tens_dimension(self, op):
-        "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+        "Tensor dimensions must be in the range [{}, {}]"
         tens_min, tens_max = self.tens_dim_range
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.large_tens_dims_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
@@ -135,16 +141,20 @@ class TosaSupportedOperators:
     # TODO This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     def constraint_rank(self, op):
-        "Tensor rank must be <= 4, if not elementwise"
+        "Tensor rank must be <= 6 or <= 4 depending on operator"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.rank_unlimited_ops:
+            if op.type in self.rank6_limited_ops:
+                rank_limit = 6
+            else:
+                rank_limit = 4
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
             for tens in tensors:
                 rank = len(tens.shape)
-                if not rank <= 4:
+                if not rank <= rank_limit:
                     valid = False
                     extra.append(f"Tensor '{tens.name}' has rank: {rank}")
         return valid, ", ".join(extra)
@@ -152,10 +162,10 @@ class TosaSupportedOperators:
     # TODO This is for a HW limitation, that is to be resolved in SW later on
    @classmethod
     def constraint_batch(self, op):
-        "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
+        "If Tensor rank is 4 batch of ifms/ofm must be 1"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.batch_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
-- 
cgit v1.2.1
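
For reference, a minimal standalone sketch (not part of the patch) of the dimension-collapsing idea behind reshape_concat_shape above: every dimension before the concat axis is folded into H, the axis itself becomes W, and everything after it is folded into C, so a concat of any rank can be handled by the existing 4D path with axis4D == 2. The helper name collapse_to_hwc and the example shapes below are illustrative only.

# Illustrative sketch mirroring reshape_concat_shape; not part of the patch.
def collapse_to_hwc(shape, axis):
    rank = len(shape)
    if axis < 0:
        axis += rank
    new_h = 1
    for i in range(axis):
        new_h *= shape[i]  # fold leading dims into H
    new_c = 1
    for i in range(axis + 1, rank):
        new_c *= shape[i]  # fold trailing dims into C
    # If the concat axis is the innermost dimension, C collapses to 1
    return [new_h, shape[axis], 1] if axis == rank - 1 else [new_h, shape[axis], new_c]

# A rank-5 concat along axis 2 becomes a rank-3 concat along W:
assert collapse_to_hwc([2, 3, 5, 4, 7], 2) == [6, 5, 28]
# Concatenating along the last axis leaves C == 1:
assert collapse_to_hwc([2, 3, 4], -1) == [6, 4, 1]

After this reshape, rewrite_concat in the patch only has to advance the write offset along that single axis, emitting one elementwise add per input.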