diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-23 13:52:34 +0200 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-28 13:57:12 +0200 |
commit | c2b129dbdd9b493fee904bb07838bb0b9e247e96 (patch) | |
tree | 9a4e02909ebf74839e459278aa5ecb59e3fa278a /ethosu/vela/tosa_graph_optimiser.py | |
parent | 3f22ec2025c8e1afe6780785fd8c62c015824a63 (diff) | |
download | ethos-u-vela-c2b129dbdd9b493fee904bb07838bb0b9e247e96.tar.gz |
TOSA: Decomposition of CONCAT
-Added support for an unlimited number of dimensions
-Added support for Tensors with dimension sizes
exceeding the maximum limit of the NPU.
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I3cc7327ac759e69042a600e686160aeb18a5ec59
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 144 |
1 file changed, 98 insertions, 46 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 1e059cc4..5cd9d210 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -265,32 +265,21 @@ def remove_splitsliceread(op, arch): DebugDatabase.add_optimised(op, add_op) -def rewrite_concat_ops(op, arch): +def rewrite_concat(op): if not op.run_on_npu or not op.type == Op.Concat: return - axis_4D = 0 - ofm = op.ofm - ofm.ops = [] offset = 0 - inputs = op.inputs - axis = op.attrs["axis"] + axis_4D = op.attrs["axis4D"] for idx, inp in enumerate(inputs): - op.ifm_shapes[idx] = Shape4D(inp.shape) - if axis >= 0: - axis_4D = axis + (4 - len(inp.shape)) - else: - axis_4D = axis write_offset = [0, 0, 0, 0] write_offset[axis_4D] = offset concat_end = offset + op.ifm_shapes[idx][axis_4D] create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)) offset = concat_end - assert ofm.shape[axis] == offset - - return op + assert op.ofm_shapes[0][axis_4D] == offset def remove_reshapes(op, arch): @@ -503,12 +492,15 @@ def convert_table_to_lut(op, arch, nng): return convert_to_lut(op, table.values, "table") -def decompose_tensors_hwc(op): +def decompose_elem_tensors_hwc(op): + """ + Decomposes elementwise op if any of the ifm(s)/ofm are to large in any dimension to be handled by the NPU + """ max_t_size = 65535 - ofm_shape = op.ofm_shapes[0] - ifm_shape = op.ifm_shapes[0] + ofm_shape = op.write_shape if op.write_shape is not None else op.ofm_shapes[0] + ifm_shape = op.read_shapes[0] if op.read_shapes[0] is not None else op.ifm_shapes[0] ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None - + ifm2_shape = op.read_shapes[1] if op.read_shapes[1] is not None else ifm2_shape limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size) if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()): @@ -536,7 +528,6 @@ def decompose_tensors_hwc(op): ifm2_part_shape = ifm2_shape.clip(ifm2_offset, 
limit_shape) ifm2_cut = (ifm2_offset, ifm2_part_shape) else: - ifm2_offset = None ifm2_cut = (None, None) create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut) @@ -582,17 +573,23 @@ def get_elem_shapes_removed_singles(op): """ Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm """ - rank = len(op.ofm.shape) binary = op.ifm2 is not None + ofm_shape = op.ofm_shapes[0].as_list() if len(op.ofm_shapes) > 0 else op.ofm.shape + ifm_shape = op.ifm_shapes[0].as_list() if len(op.ifm_shapes) > 0 else op.ifm.shape + if binary: + ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape + + rank = len(ofm_shape) new_ofm_shape = [] new_ifm_shape = [] new_ifm2_shape = [] for idx in range(rank): - if op.ofm.shape[idx] != 1: - new_ofm_shape.append(op.ofm.shape[idx]) - new_ifm_shape.append(op.ifm.shape[idx]) + if ofm_shape[idx] != 1: + new_ofm_shape.append(ofm_shape[idx]) + new_ifm_shape.append(ifm_shape[idx]) if binary: - new_ifm2_shape.append(op.ifm2.shape[idx]) + new_ifm2_shape.append(ifm2_shape[idx]) + if new_ofm_shape == []: new_ofm_shape = [1] new_ifm_shape = [1] @@ -614,7 +611,6 @@ def decomp_dims_elementwise(op): ifm2 = op.ifm2 ofm = op.ofm binary = op.ifm2 is not None - assert len(ofm.shape) <= 6 # Remove dimensions that are all 1 new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op) @@ -679,19 +675,74 @@ def decomp_dims_elementwise(op): def decomp_elementwise(tens, arch, nng): """ - Decompose elementwise ops with Rank > 3 (H,W,D). + Decompose elementwise ops with Rank > 3 (H,W,C). 
Decompose size of tensors exceeding NPU max size """ - if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op(): + tens_ops = tens.ops.copy() + for op in tens_ops: + if op.type.is_elementwise_op(): + decomp_list = decomp_dims_elementwise(op) + for part_op in decomp_list: + decompose_elem_tensors_hwc(part_op) + return tens + + +def reshape_concat_shape(shape, rank, axis): + new_h = 1 + for i in range(axis): + new_h *= shape[i] + new_c = 1 + for i in range(axis + 1, rank): + new_c *= shape[i] + if axis == (rank - 1): + new_shape = [new_h, shape[axis], 1] + else: + new_shape = [new_h, shape[axis], new_c] + return new_shape + + +def reshape_concat(op): + """ + Reshapes concat ops with Rank > 3 (H,W,C). + """ + ofm = op.ofm + rank = len(ofm.shape) + axis = op.attrs["axis"] + if axis < 0: + axis += rank + + if rank > 3: + # Reshape so that axis in to be concatenated is the W dimension + # Reshape inputs + for inp in op.inputs: + new_shape = reshape_concat_shape(inp.shape, rank, axis) + op.ifm_shapes.append(Shape4D(new_shape)) + # Reshape output + new_shape = reshape_concat_shape(ofm.shape, rank, axis) + op.ofm_shapes.append(Shape4D(new_shape)) + op.attrs["axis4D"] = 2 + else: + for inp in op.inputs: + op.ifm_shapes.append(Shape4D(inp.shape)) + op.ofm_shapes.append(Shape4D(ofm.shape)) + op.attrs["axis4D"] = axis + (4 - rank) + + +def decomp_rewrite_concat(tens, arch, nng): + """ + Decompose concat ops with Rank > 3 (H,W,C). 
+ Rewrite of concat to elementwise operations + """ + if len(tens.ops) == 1 and tens.ops[0].type == Op.Concat: op = tens.ops[0] - rank = len(op.ofm.shape) - assert rank <= 6 - decomp_list = [] - decomp_list = decomp_dims_elementwise(op) + reshape_concat(op) + rewrite_concat(op) + + op.ofm.ops.remove(op) + for inp in op.inputs: + inp.consumer_list.remove(op) - for part_op in decomp_list: - decompose_tensors_hwc(part_op) return tens @@ -714,21 +765,27 @@ def supported_operator_check(op, arch, nng): def tosa_optimise_graph(nng, arch): - # Decomposing to 4 dimensions + # TODO the supported operator checking need to be split in semantic and HW checks for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False + nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False, ) - # Pre-processing step - pre_process_list = [ - supported_operator_check, - set_ifm_ofm_op_shapes, - ] + # Decomposing and rewrite of concat + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False + ) + # Decomposing of elementwise for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, + nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False + ) + + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False, ) # Removal of Transpose @@ -743,11 +800,6 @@ def tosa_optimise_graph(nng, arch): nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False, ) - # Rewrite concat ops - for idx, sg in enumerate(nng.subgraphs): - rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops]) - 
sg.refresh_after_modification() - # Removal of reshapes for sg in nng.subgraphs: rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes]) |