aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 1ef04449..f4aa4534 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -503,6 +503,142 @@ def convert_table_to_lut(op, arch, nng):
return convert_to_lut(op, table.values, "table")
+def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n):
+ part_op = op.clone()
+ offset = Shape4D(0, 0, 0, 0)
+
+ part_op.read_offsets[0] = offset.with_batch(ifm_offset_n)
+ part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1)
+ part_op.write_offset = offset.with_batch(ofm_offset_n)
+ part_op.write_shape = op.ofm_shapes[0].with_batch(1)
+ part_op.ifm_shapes = op.ifm_shapes.copy()
+ part_op.ofm_shapes = op.ofm_shapes.copy()
+ part_op.ifm.consumer_list.append(part_op)
+ op.ofm.ops.append(part_op)
+ if ifm2_offset_n:
+ part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n)
+ part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1)
+ part_op.ifm2.consumer_list.append(part_op)
+
+
+def get_nhwc_stride(shape):
+ stride_x = shape.depth
+ stride_y = shape.width * stride_x
+ stride_n = shape.height * stride_y
+ return Shape4D(stride_n, stride_y, stride_x, 1)
+
+
+def decomp_unary_elementwise(op):
+ """
+ Decompose binary elementwise ops with Rank > 3 (H,W,D).
+ If Rank > 3, all the dimensions above H are viewed as the N dimension.
+ the elementwise operation will be decomposed to N (of ofm) elementwise operations.
+ By reading and writing with offsets from/to the ifm/ofm.
+ """
+ ifm = op.ifm
+ ofm = op.ofm
+ assert op.type.is_unary_elementwise_op()
+ assert None not in (ifm, ofm)
+ assert ifm.shape == ofm.shape
+
+ rank = len(ofm.shape)
+ if rank > 3:
+ n = rank - 3
+ ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+ new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+ op.ifm_shapes.append(Shape4D(new_ofm_shape))
+ op.ofm_shapes.append(Shape4D(new_ofm_shape))
+
+ if new_ofm_shape[0] == 1:
+ return
+
+ for height in range(ofm_decomp_shape.height):
+ for width in range(ofm_decomp_shape.width):
+ for depth in range(ofm_decomp_shape.depth):
+ ifm_offset, ofm_offset = Shape4D(0, height, width, depth)
+ create_elem_part_op(op, ifm_offset, None, ofm_offset)
+
+ ifm.consumer_list.remove(op)
+ ofm.ops.remove(op)
+ return
+
+
+def decomp_binary_elementwise(op):
+ """
+ Decompose binary elementwise ops with Rank > 3 (H,W,D).
+ If Rank > 3, all the dimensions above H are viewed as the N dimension.
+ the elementwise operation will be decomposed to N (of ofm) elementwise operations.
+ By reading and writing with offsets from/to the ifm(s)/ofm.
+ Note: Broadcast need to be handled, and TOSA allowes for broadcast by both ifm and ifm2
+ """
+
+ ifm = op.ifm
+ ifm2 = op.ifm2
+ ofm = op.ofm
+ assert op.type.is_binary_elementwise_op()
+ assert None not in (ifm, ifm2, ofm)
+
+ rank = len(ofm.shape)
+ if rank > 3:
+ n = rank - 3
+ ofm_decomp_shape = Shape4D(ofm.shape[0:n])
+ ifm_decomp_shape = Shape4D(ifm.shape[0:n])
+ ifm2_decomp_shape = Shape4D(ifm2.shape[0:n])
+
+ ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
+ ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
+ ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+
+ new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
+ new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:]
+ new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:]
+
+ op.ofm_shapes.append(Shape4D(new_ofm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm2_shape))
+
+ if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1:
+ return
+
+ for height in range(ofm_decomp_shape.height):
+ for width in range(ofm_decomp_shape.width):
+ for depth in range(ofm_decomp_shape.depth):
+ ofm_offset = Shape4D(0, height, width, depth)
+
+ ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+
+ ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+
+ ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride)
+ ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride)
+ ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride)
+ create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n)
+
+ ifm.consumer_list.remove(op)
+ ifm2.consumer_list.remove(op)
+ ofm.ops.remove(op)
+ return
+
+
+def decomp_elementwise(tens, arch, nng):
+ """
+ Decompose elementwise ops with Rank > 3 (H,W,D).
+ """
+ assert len(tens.ops) == 1
+
+ if tens.ops[0].type.is_binary_elementwise_op():
+ decomp_binary_elementwise(tens.ops[0])
+ elif tens.ops[0].type.is_unary_elementwise_op():
+ decomp_unary_elementwise(tens.ops[0])
+ return tens
+
+
def fixup_quantization(op, arch, nng):
if op.ifm and op.ifm.quantization.zero_point is None:
op.ifm.quantization.zero_point = 0
@@ -521,6 +657,13 @@ def supported_operator_check(op, arch, nng):
def tosa_optimise_graph(nng, arch):
+
+ # Decomposing to 4 dimensions
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
+ )
+
# Pre-processing step
pre_process_list = [
supported_operator_check,