about summary refs log tree commit diff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py 234
1 files changed, 145 insertions, 89 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index f4aa4534..1e059cc4 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -503,21 +503,71 @@ def convert_table_to_lut(op, arch, nng):
return convert_to_lut(op, table.values, "table")
-def create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n):
- part_op = op.clone()
- offset = Shape4D(0, 0, 0, 0)
+def decompose_tensors_hwc(op):
+ max_t_size = 65535
+ ofm_shape = op.ofm_shapes[0]
+ ifm_shape = op.ifm_shapes[0]
+ ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
+
+ limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)
+
+ if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
+ ofm_split = ofm_shape.floordiv_const(max_t_size).add(1, 1, 1, 1)
- part_op.read_offsets[0] = offset.with_batch(ifm_offset_n)
- part_op.read_shapes[0] = op.ifm_shapes[0].with_batch(1)
- part_op.write_offset = offset.with_batch(ofm_offset_n)
- part_op.write_shape = op.ofm_shapes[0].with_batch(1)
+ for height in range(ofm_split.height):
+ for width in range(ofm_split.width):
+ for depth in range(ofm_split.depth):
+ ofm_offset = Shape4D(0, height * max_t_size, width * max_t_size, depth * max_t_size)
+ ofm_part_shape = ofm_shape.clip(ofm_offset, limit_shape)
+ ofm_cut = (ofm_offset, ofm_part_shape)
+
+ ifm_d = depth * max_t_size if ifm_shape.depth == ofm_shape.depth else 0
+ ifm_w = width * max_t_size if ifm_shape.width == ofm_shape.width else 0
+ ifm_h = height * max_t_size if ifm_shape.height == ofm_shape.height else 0
+ ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+ ifm_part_shape = ifm_shape.clip(ifm_offset, limit_shape)
+ ifm_cut = (ifm_offset, ifm_part_shape)
+
+ if ifm2_shape is not None:
+ ifm2_d = depth * max_t_size if ifm2_shape.depth == ofm_shape.depth else 0
+ ifm2_w = width * max_t_size if ifm2_shape.width == ofm_shape.width else 0
+ ifm2_h = height * max_t_size if ifm2_shape.height == ofm_shape.height else 0
+ ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+ ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
+ ifm2_cut = (ifm2_offset, ifm2_part_shape)
+ else:
+ ifm2_offset = None
+ ifm2_cut = (None, None)
+
+ create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
+ op.ofm.ops.remove(op)
+ op.ifm.consumer_list.remove(op)
+ if op.ifm2 is not None:
+ op.ifm2.consumer_list.remove(op)
+ return
+
+
+def create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut):
+ part_op = op.clone()
+ ifm_read_offset = op.read_offsets[0] if op.read_offsets[0] is not None else Shape4D(0, 0, 0, 0)
+ ofm_write_offset = op.write_offset if op.write_offset is not None else Shape4D(0, 0, 0, 0)
+ ifm_offset, ifm_shape = ifm_cut
+ ofm_offset, ofm_shape = ofm_cut
+
+ part_op.read_offsets[0] = ifm_read_offset + ifm_offset
+ part_op.read_shapes[0] = ifm_shape
+ part_op.write_offset = ofm_write_offset + ofm_offset
+ part_op.write_shape = ofm_shape
part_op.ifm_shapes = op.ifm_shapes.copy()
part_op.ofm_shapes = op.ofm_shapes.copy()
part_op.ifm.consumer_list.append(part_op)
op.ofm.ops.append(part_op)
- if ifm2_offset_n:
- part_op.read_offsets[1] = offset.with_batch(ifm2_offset_n)
- part_op.read_shapes[1] = op.ifm_shapes[1].with_batch(1)
+
+ ifm2_offset, ifm2_shape = ifm2_cut
+ if ifm2_offset:
+ ifm2_read_offset = op.read_offsets[1] if op.read_offsets[1] is not None else Shape4D(0, 0, 0, 0)
+ part_op.read_offsets[1] = ifm2_read_offset + ifm2_offset
+ part_op.read_shapes[1] = ifm2_shape
part_op.ifm2.consumer_list.append(part_op)
@@ -528,114 +578,120 @@ def get_nhwc_stride(shape):
return Shape4D(stride_n, stride_y, stride_x, 1)
-def decomp_unary_elementwise(op):
+def get_elem_shapes_removed_singles(op):
"""
- Decompose binary elementwise ops with Rank > 3 (H,W,D).
- If Rank > 3, all the dimensions above H are viewed as the N dimension.
- the elementwise operation will be decomposed to N (of ofm) elementwise operations.
- By reading and writing with offsets from/to the ifm/ofm.
+ Returns the shapes of the ifm(s)/ofm after removing all the dimensions that are 1 for all of the ifm(s)/ofm
"""
- ifm = op.ifm
- ofm = op.ofm
- assert op.type.is_unary_elementwise_op()
- assert None not in (ifm, ofm)
- assert ifm.shape == ofm.shape
-
- rank = len(ofm.shape)
- if rank > 3:
- n = rank - 3
- ofm_decomp_shape = Shape4D(ofm.shape[0:n])
- new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
- op.ifm_shapes.append(Shape4D(new_ofm_shape))
- op.ofm_shapes.append(Shape4D(new_ofm_shape))
-
- if new_ofm_shape[0] == 1:
- return
-
- for height in range(ofm_decomp_shape.height):
- for width in range(ofm_decomp_shape.width):
- for depth in range(ofm_decomp_shape.depth):
- ifm_offset, ofm_offset = Shape4D(0, height, width, depth)
- create_elem_part_op(op, ifm_offset, None, ofm_offset)
-
- ifm.consumer_list.remove(op)
- ofm.ops.remove(op)
- return
-
-
-def decomp_binary_elementwise(op):
+ rank = len(op.ofm.shape)
+ binary = op.ifm2 is not None
+ new_ofm_shape = []
+ new_ifm_shape = []
+ new_ifm2_shape = []
+ for idx in range(rank):
+ if op.ofm.shape[idx] != 1:
+ new_ofm_shape.append(op.ofm.shape[idx])
+ new_ifm_shape.append(op.ifm.shape[idx])
+ if binary:
+ new_ifm2_shape.append(op.ifm2.shape[idx])
+ if new_ofm_shape == []:
+ new_ofm_shape = [1]
+ new_ifm_shape = [1]
+ new_ifm2_shape = [1] if binary else None
+
+ return new_ofm_shape, new_ifm_shape, new_ifm2_shape
+
+
+def decomp_dims_elementwise(op):
"""
- Decompose binary elementwise ops with Rank > 3 (H,W,D).
+ Decompose elementwise ops with Rank > 3 (H,W,D).
If Rank > 3, all the dimensions above H are viewed as the N dimension.
the elementwise operation will be decomposed to N (of ofm) elementwise operations.
By reading and writing with offsets from/to the ifm(s)/ofm.
- Note: Broadcast need to be handled, and TOSA allowes for broadcast by both ifm and ifm2
+ Note: Broadcast needs to be handled for binary elementwise ops, and TOSA allows broadcast by both ifm and ifm2
"""
ifm = op.ifm
ifm2 = op.ifm2
ofm = op.ofm
- assert op.type.is_binary_elementwise_op()
- assert None not in (ifm, ifm2, ofm)
+ binary = op.ifm2 is not None
+ assert len(ofm.shape) <= 6
+
+ # Remove dimensions that are all 1
+ new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
+ rank = len(new_ofm_shape)
- rank = len(ofm.shape)
if rank > 3:
n = rank - 3
- ofm_decomp_shape = Shape4D(ofm.shape[0:n])
- ifm_decomp_shape = Shape4D(ifm.shape[0:n])
- ifm2_decomp_shape = Shape4D(ifm2.shape[0:n])
-
+ ofm_decomp_shape = Shape4D(new_ofm_shape[0:n])
ofm_decomp_stride = get_nhwc_stride(ofm_decomp_shape)
- ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
- ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
-
- new_ofm_shape = [ofm_decomp_shape.elements()] + ofm.shape[n:]
- new_ifm_shape = [ifm_decomp_shape.elements()] + ifm.shape[n:]
- new_ifm2_shape = [ifm2_decomp_shape.elements()] + ifm2.shape[n:]
-
- op.ofm_shapes.append(Shape4D(new_ofm_shape))
- op.ifm_shapes.append(Shape4D(new_ifm_shape))
- op.ifm_shapes.append(Shape4D(new_ifm2_shape))
-
- if new_ifm_shape[0] == new_ifm2_shape[0] == new_ofm_shape[0] == 1:
- return
+ ofm_part_shape = Shape4D(new_ofm_shape[n:])
+ op.ofm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
+
+ if binary:
+ ifm_decomp_shape = Shape4D(new_ifm_shape[0:n])
+ ifm2_decomp_shape = Shape4D(new_ifm2_shape[0:n])
+ ifm_decomp_stride = get_nhwc_stride(ifm_decomp_shape)
+ ifm2_decomp_stride = get_nhwc_stride(ifm2_decomp_shape)
+ ifm_part_shape = Shape4D(new_ifm_shape[n:])
+ ifm2_part_shape = Shape4D(new_ifm2_shape[n:])
+ op.ifm_shapes.append(Shape4D([ifm_decomp_shape.elements()] + new_ifm_shape[n:]))
+ op.ifm_shapes.append(Shape4D([ifm2_decomp_shape.elements()] + new_ifm2_shape[n:]))
+ else:
+ op.ifm_shapes.append(Shape4D([ofm_decomp_shape.elements()] + new_ofm_shape[n:]))
+ op_list = []
for height in range(ofm_decomp_shape.height):
for width in range(ofm_decomp_shape.width):
for depth in range(ofm_decomp_shape.depth):
ofm_offset = Shape4D(0, height, width, depth)
+ ofm_offset = Shape4D(ofm_offset.dot_prod(ofm_decomp_stride), 0, 0, 0)
+ ofm_cut = (ofm_offset, ofm_part_shape)
+
+ if binary:
+ ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
+ ifm_offset = Shape4D(ifm_offset.dot_prod(ifm_decomp_stride), 0, 0, 0)
+ ifm_cut = (ifm_offset, ifm_part_shape)
+
+ ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
+ ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
+ ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
+ ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
+ ifm2_offset = Shape4D(ifm2_offset.dot_prod(ifm2_decomp_stride), 0, 0, 0)
+ ifm2_cut = (ifm2_offset, ifm2_part_shape)
+ op_list.append(create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut))
+ else:
+ op_list.append(create_elem_part_op(op, ofm_cut, None, ofm_cut))
- ifm_d = depth if ifm_decomp_shape.depth == ofm_decomp_shape.depth else 0
- ifm_w = width if ifm_decomp_shape.width == ofm_decomp_shape.width else 0
- ifm_h = height if ifm_decomp_shape.height == ofm_decomp_shape.height else 0
- ifm_offset = Shape4D(0, ifm_h, ifm_w, ifm_d)
-
- ifm2_d = depth if ifm2_decomp_shape.depth == ofm_decomp_shape.depth else 0
- ifm2_w = width if ifm2_decomp_shape.width == ofm_decomp_shape.width else 0
- ifm2_h = height if ifm2_decomp_shape.height == ofm_decomp_shape.height else 0
- ifm2_offset = Shape4D(0, ifm2_h, ifm2_w, ifm2_d)
-
- ofm_offset_n = ofm_offset.dot_prod(ofm_decomp_stride)
- ifm_offset_n = ifm_offset.dot_prod(ifm_decomp_stride)
- ifm2_offset_n = ifm2_offset.dot_prod(ifm2_decomp_stride)
- create_elem_part_op(op, ifm_offset_n, ifm2_offset_n, ofm_offset_n)
-
- ifm.consumer_list.remove(op)
- ifm2.consumer_list.remove(op)
ofm.ops.remove(op)
- return
+ ifm.consumer_list.remove(op)
+ if binary:
+ ifm2.consumer_list.remove(op)
+ else:
+ op.ofm_shapes.append(Shape4D(new_ofm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm_shape))
+ op.ifm_shapes.append(Shape4D(new_ifm2_shape))
+
+ return [op]
def decomp_elementwise(tens, arch, nng):
"""
Decompose elementwise ops with Rank > 3 (H,W,D).
+ Decompose tensors whose size exceeds the NPU max size
"""
- assert len(tens.ops) == 1
+ if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op():
+ op = tens.ops[0]
+ rank = len(op.ofm.shape)
+ assert rank <= 6
+
+ decomp_list = []
+ decomp_list = decomp_dims_elementwise(op)
- if tens.ops[0].type.is_binary_elementwise_op():
- decomp_binary_elementwise(tens.ops[0])
- elif tens.ops[0].type.is_unary_elementwise_op():
- decomp_unary_elementwise(tens.ops[0])
+ for part_op in decomp_list:
+ decompose_tensors_hwc(part_op)
return tens