diff options
author | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-20 10:47:47 +0200 |
---|---|---|
committer | Patrik Gustavsson <patrik.gustavsson@arm.com> | 2021-09-20 13:27:15 +0200 |
commit | 46408a8049f6a51dda5bfa8a4c9959e037120265 (patch) | |
tree | 68595457843f3ff4193da0542afbe5de56da8d31 /ethosu/vela/tosa_graph_optimiser.py | |
parent | f436ada9caea87ec2dd686a92e41a15c1dcdeb1d (diff) | |
download | ethos-u-vela-46408a8049f6a51dda5bfa8a4c9959e037120265.tar.gz |
TOSA: Elementwise Rank > 4 and Batch > 1
Added support for elementwise operations:
-Support for up to Rank == 6
-Support for Batch > 1 for Rank == 4
-For binary elementwise ops this includes handling
of broadcasting in dimensions above H-dimension
Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I73850bbfb288077a99bd2ceecbf989172016da24
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tosa_graph_optimiser.py | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py index 1ef04449..f4aa4534 100644 --- a/ethosu/vela/tosa_graph_optimiser.py +++ b/ethosu/vela/tosa_graph_optimiser.py @@ -503,6 +503,142 @@ def convert_table_to_lut(op, arch, nng): return convert_to_lut(op, table.values, "table") +def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n): + part_op = op.clone() + offset = Shape4D(0, 0, 0, 0) + + part_op.read_offsets[0] = offset.with_batch(ifm_offset_n) + part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1) + part_op.write_offset = offset.with_batch(ofm_offset_n) + part_op.write_shape = op.ofm_shapes[0].with_batch(1) + part_op.ifm_shapes = op.ifm_shapes.copy() + part_op.ofm_shapes = op.ofm_shapes.copy() + part_op.ifm.consumer_list.append(part_op) + op.ofm.ops.append(part_op) + if ifm2_offset_n: + part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n) + part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1) + part_op.ifm2.consumer_list.append(part_op) + + +def get_nhwc_stride(shape): + stride_x = shape.depth + stride_y = shape.width * stride_x + stride_n = shape.height * stride_y + return Shape4D(stride_n, stride_y, stride_x, 1) + + +def decomp_unary_elementwise(op): + """ + Decompose unary elementwise ops with Rank > 3 (H,W,D). + If Rank > 3, all the dimensions above H are viewed as the N dimension. + the elementwise operation will be decomposed to N (of ofm) elementwise operations. + By reading and writing with offsets from/to the ifm/ofm. 
+ """ + ifm = op.ifm + ofm = op.ofm + assert op.type.is_unary_elementwise_op() + assert None not in (ifm, ofm) + assert ifm.shape == ofm.shape + + rank = len(ofm.shape) + if rank > 3: + n = rank - 3 + ofm_decomp_shape = Shape4D(ofm.shape[0:n]) + new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:] + op.ifm_shapes.append(Shape4D(new_ofm_shape)) + op.ofm_shapes.append(Shape4D(new_ofm_shape)) + + if new_ofm_shape[0] == 1: + return + + for height in range(ofm_decomp_shape.height): + for width in range(ofm_decomp_shape.width): + for depth in range(ofm_decomp_shape.depth): + ifm_offset, ofm_offset = Shape4D(0, height, width, depth) + create_elem_part_op(op, ifm_offset, None, ofm_offset) + + ifm.consumer_list.remove(op) + ofm.ops.remove(op) + return + + +def decomp_binary_elementwise(op): + """ + Decompose binary elementwise ops with Rank > 3 (H,W,D). + If Rank > 3, all the dimensions above H are viewed as the N dimension. + the elementwise operation will be decomposed to N (of ofm) elementwise operations. + By reading and writing with offsets from/to the ifm(s)/ofm. 
+ Note: Broadcast needs to be handled, and TOSA allows broadcast by both ifm and ifm2 + """ + + ifm = op.ifm + ifm2 = op.ifm2 + ofm = op.ofm + assert op.type.is_binary_elementwise_op() + assert None not in (ifm, ifm2, ofm) + + rank = len(ofm.shape) + if rank > 3: + n = rank - 3 + ofm_decomp_shape = Shape4D(ofm.shape[0:n]) + ifm_decomp_shape = Shape4D(ifm.shape[0:n]) + ifm2_decomp_shape = Shape4D(ifm2.shape[0:n]) + + ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape) + ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape) + ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape) + + new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:] + new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:] + new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:] + + op.ofm_shapes.append(Shape4D(new_ofm_shape)) + op.ifm_shapes.append(Shape4D(new_ifm_shape)) + op.ifm_shapes.append(Shape4D(new_ifm2_shape)) + + if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1: + return + + for height in range(ofm_decomp_shape.height): + for width in range(ofm_decomp_shape.width): + for depth in range(ofm_decomp_shape.depth): + ofm_offset = Shape4D(0, height, width, depth) + + ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0 + ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0 + ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0 + ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d) + + ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0 + ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0 + ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0 + ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d) + + ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride) + ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride) + ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride) + create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, 
ofm_offset_n) + + ifm.consumer_list.remove(op) + ifm2.consumer_list.remove(op) + ofm.ops.remove(op) + return + + +def decomp_elementwise(tens, arch, nng): + """ + Decompose elementwise ops with Rank > 3 (H,W,D). + """ + assert len(tens.ops) == 1 + + if tens.ops[0].type.is_binary_elementwise_op(): + decomp_binary_elementwise(tens.ops[0]) + elif tens.ops[0].type.is_unary_elementwise_op(): + decomp_unary_elementwise(tens.ops[0]) + return tens + + def fixup_quantization(op, arch, nng): if op.ifm and op.ifm.quantization.zero_point is None: op.ifm.quantization.zero_point = 0 @@ -521,6 +657,13 @@ def supported_operator_check(op, arch, nng): def tosa_optimise_graph(nng, arch): + + # Decomposing to 4 dimensions + for idx, sg in enumerate(nng.subgraphs): + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False + ) + # Pre-processing step pre_process_list = [ supported_operator_check, |