author    Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-04-08 09:04:00 +0200
committer Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-04-08 11:21:35 +0200
commit    ee99bb124b088430b97d205df9fc90a1e9412e0c (patch)
tree      bfb1a6799f9e9ff9d0f387471277b8a26edbab70 /ethosu/vela/graph_optimiser.py
parent    95b279f1454d58a93238851cb5ff394c7782ad32 (diff)
download  ethos-u-vela-ee99bb124b088430b97d205df9fc90a1e9412e0c.tar.gz
MLBEDSW-4334 Non-linear format decision in graph opt.
Refactor the check for whether the non-linear tensor format can be used:
- The avoid_NHCWB16 flag is replaced with needs_linear_format.
- The restriction checks are consolidated into one function in the graph optimiser.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Iec5c7996a1a6039cad052197f1ae56f7c0290440
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r--  ethosu/vela/graph_optimiser.py | 138
1 file changed, 109 insertions(+), 29 deletions(-)
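
Before the diff itself, a quick sketch of the semantic flip this commit makes: the old avoid_NHCWB16 flag defaulted to "NHCWB16 allowed" and was set piecemeal at every call site that spotted a restriction, while the new needs_linear_format flag defaults to "linear required" and is cleared in exactly one place once all restrictions pass. The Tensor below is a simplified stand-in for illustration only, not the real vela class:

class Tensor:
    """Simplified stand-in for the vela tensor class, illustration only."""

    def __init__(self, shape):
        self.shape = shape
        # Old scheme: avoid_NHCWB16 = False by default; many rewrites opted
        # out by flipping it to True when they hit a restriction.
        # New scheme: linear format is assumed until the single pass
        # check_format_restrictions proves NHCWB16 is safe and clears this flag.
        self.needs_linear_format = True
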
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 56932dbe..dd540a79 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -104,8 +104,6 @@ def rewrite_concat_ops(op, arch):
for idx, inp in enumerate(op.inputs):
op.ifm_shapes[idx] = Shape4D(desired_shape)
- if Shape4D(inp.shape) != op.ifm_shapes[idx]:
- inp.avoid_NHCWB16 = True
op.type = Op.PackReshaped
inputs, axis = op.get_concat_inputs_axis()
@@ -125,12 +123,7 @@ def rewrite_concat_ops(op, arch):
offset = concat_end
assert ofm.shape[axis] == offset
- # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
- # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
- # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
- # and those addresses are always 16 byte aligned due to the NHCWB16 format.
- if axis == -1 or axis == (len(ofm.shape) - 1):
- ofm.avoid_NHCWB16 = any(op2.write_offset.depth % 16 != 0 for op2 in ofm.ops if op2.write_offset is not None)
+ return op
def rewrite_split_ops(tens, arch, nng):
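
To make the alignment rule removed here (and reinstated below as avoid_nhcwb16_for_concat) concrete: NHCWB16 stores the channel dimension in bricks of 16 values, so a concat slice whose write offset in C is a multiple of 16 starts on a brick boundary, and any other offset does not. A tiny self-contained check of the predicate, using SimpleNamespace stubs in place of the real op and tensor classes:

from types import SimpleNamespace

def avoid_nhcwb16_for_concat(tens):
    # Same predicate as the helper added later in this commit.
    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)

def concat_ofm(depth_offsets):
    # Stub ofm whose producing ops write at the given C-dimension offsets.
    ops = [SimpleNamespace(write_offset=SimpleNamespace(depth=d)) for d in depth_offsets]
    return SimpleNamespace(ops=ops)

assert not avoid_nhcwb16_for_concat(concat_ofm([0, 16, 48]))  # brick-aligned: NHCWB16 ok
assert avoid_nhcwb16_for_concat(concat_ofm([0, 8]))           # offset 8: linear format required
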
@@ -171,10 +164,6 @@ def rewrite_split_ops(tens, arch, nng):
offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
- # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
- if (offset_start[-1] % 16) != 0:
- inp.avoid_NHCWB16 = True
-
new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
new_op.run_on_npu = True
new_op.set_output_tensor(tens)
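
The same brick constraint applies in the other direction for splits, as the helper avoid_nhcwb16_for_split added in the next hunk spells out: a consumer whose read starts mid-brick (C offset not a multiple of 16) forces the input tensor to stay linear. Reduced to its core, for a single consumer:

def split_read_needs_linear(read_offset_c):
    # Core of avoid_nhcwb16_for_split: a read that starts inside a
    # 16-channel brick cannot be addressed in NHCWB16.
    return (read_offset_c % 16) != 0

assert split_read_needs_linear(32) is False  # starts on a brick boundary
assert split_read_needs_linear(10) is True   # starts mid-brick: linear input
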
@@ -224,6 +213,108 @@ def remove_SplitSliceRead(op, arch):
DebugDatabase.add_optimised(op, avgpool_op)
+def avoid_nhcwb16_for_concat(tens):
+    # If axis corresponds to the C-dimension, NHCWB16 can only be used in the output if all the concat_start's
+    # are a multiple of 16. This is because only then will the ofm address offset be 16-byte aligned for all
+    # operations. For other values of axis the address offsets will be 16-byte aligned, as they are all based
+    # on c = 0 and those addresses are always 16-byte aligned due to the NHCWB16 format.
+ return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
+
+
+def avoid_nhcwb16_for_split(tens):
+    # If the read offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided in the input
+ for cons_op in tens.consumer_list:
+ if cons_op.ifm == tens:
+ read_offset = cons_op.read_offsets[0]
+ elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+ read_offset = cons_op.read_offsets[1]
+ else:
+ assert False
+ if read_offset is not None and (read_offset[-1] % 16) != 0:
+ return True
+ return False
+
+
+def avoid_nhcwb16_for_shapes(tens):
+ # check all producers/consumers to see if any op shape is preventing NHCWB16
+ for cons_op in tens.consumer_list:
+ if cons_op.ifm == tens:
+ cons_op_shape = cons_op.ifm_shapes[0]
+ elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+ cons_op_shape = cons_op.ifm_shapes[1]
+ else:
+ assert False
+ if Shape4D(tens.shape) != cons_op_shape:
+ return True
+
+ for prod_op in tens.ops:
+ if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
+ return True
+
+ return False
+
+
+# Check if the non-linear format can be used
+def check_format_restrictions(tens, arch):
+ if len(tens.ops) < 1:
+ return
+ if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
+ cons is None for cons in tens.consumer_list
+ ):
+ return
+
+ if not any(cons.run_on_npu for cons in tens.consumer_list):
+ return
+ if not any(prod.run_on_npu for prod in tens.ops):
+ return
+
+ # "Concat" ofm exception:
+ if avoid_nhcwb16_for_concat(tens):
+ return
+
+ # "Split" ifm exception:
+ if avoid_nhcwb16_for_split(tens):
+ return
+
+ # Shapes checking: check all producers/consumers are NHCWB16 compatible with tens.shape
+ if avoid_nhcwb16_for_shapes(tens):
+ return
+
+ for op in tens.consumer_list:
+ if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
+ return
+ if op.type == Op.Reshape:
+ # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+ # consumers do not also need to perform a reshape or if the OFM is going to
+ # be processed by CPU operations. No-op reshape consumers with empty lists
+ # (those that have no consumers, or null-consumers used as list terminators)
+ # must use normal NHWC output.
+
+ def incompatible_consumers(oper):
+ if oper and oper.type == Op.Reshape:
+ for consumer in oper.outputs[0].consumer_list:
+ yield from incompatible_consumers(consumer)
+ yield not oper or not oper.run_on_npu
+
+ if not any(incompatible_consumers(op)):
+
+ def get_rewrites(oper):
+ if oper and oper.type == Op.Reshape:
+ for consumer in oper.outputs[0].consumer_list:
+ yield from get_rewrites(consumer)
+ yield oper
+
+ # Detect no-op reshapes by comparing their full input and output tensor shapes.
+ inshape = op.ifm_shapes[0]
+ compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
+ if not (compatible_shape and all(compatible_shape)):
+ return
+ else:
+ return
+
+ tens.needs_linear_format = False
+
+
def insert_copy_op_after_tens(tens):
tens_cons_list_copy = tens.consumer_list.copy()
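
The reshape handling in check_format_restrictions above is the subtlest part: get_rewrites recursively flattens a chain of reshapes, and NHCWB16 is kept only if every reshape in the chain is a no-op (full 4D input shape equal to each output shape) and no terminal consumer is off-NPU. A toy trace of the generator, using string-typed SimpleNamespace stubs instead of the real Op enum:

from types import SimpleNamespace

def get_rewrites(oper):
    # Same shape as the nested helper above: recurse through consumers of
    # reshape outputs, then yield the op itself.
    if oper and oper.type == "Reshape":
        for consumer in oper.outputs[0].consumer_list:
            yield from get_rewrites(consumer)
    yield oper

# Chain: reshape_a -> reshape_b -> conv (a non-reshape terminal consumer).
conv = SimpleNamespace(type="Conv2D", outputs=[])
reshape_b = SimpleNamespace(type="Reshape", outputs=[SimpleNamespace(consumer_list=[conv])])
reshape_a = SimpleNamespace(type="Reshape", outputs=[SimpleNamespace(consumer_list=[reshape_b])])

# Depth-first: terminal consumers are yielded first, the starting reshape last.
print([op.type for op in get_rewrites(reshape_a)])  # ['Conv2D', 'Reshape', 'Reshape']
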
@@ -459,8 +550,6 @@ def rewrite_fully_connected_input(op, arch, nng):
assert batch_size * n_in_elems == elms
op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
- if Shape4D(op.ifm.shape) != op.ifm_shapes[0]:
- op.ifm.avoid_NHCWB16 = True
return op
@@ -473,8 +562,6 @@ def convert_batched_fc_shape(op, arch, nng):
h, w = batching_split.get(n, (1, n))
op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
- op.ifm.avoid_NHCWB16 = True
-
# Reshape Weights to be 4D. IO becomes HWIO
weight_tensor = op.inputs[1]
weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
@@ -483,7 +570,6 @@ def convert_batched_fc_shape(op, arch, nng):
n = op.ofm_shapes[0].batch
h, w = batching_split.get(n, (1, n))
op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
- op.ofm.avoid_NHCWB16 = True
return op
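
For context on the hunk above: convert_batched_fc_shape folds a batch of n into an h x w plane so the NPU sees a single-batch 4D tensor, and this commit drops the blanket avoid_NHCWB16 the fold used to set. A sketch of the fold; the batching_split table here is an assumed example, the real mapping is defined in the function itself:

batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}  # assumed example table

def batched_fc_shape(n, depth):
    # Unknown batch sizes fall back to a 1 x n plane.
    h, w = batching_split.get(n, (1, n))
    return [1, h, w, depth]

print(batched_fc_shape(8, 64))  # [1, 2, 4, 64]
print(batched_fc_shape(3, 64))  # [1, 1, 3, 64]
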
@@ -550,9 +636,6 @@ def rewrite_stridedslice_output(op, arch, nng):
axis_4D[idx] = axis
op.ofm_shapes[idx] = Shape4D(output_shape)
- if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
- out_tens.avoid_NHCWB16 = True
-
op.attrs["split_axis_4D"] = axis_4D
return op
@@ -574,8 +657,6 @@ def rewrite_unpack_output(op, arch, nng):
for idx, out_tens in enumerate(op.outputs):
op.ofm_shapes[idx] = Shape4D(desired_output_shape)
axis_4D_list[idx] = axis_4D
- if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
- out_tens.avoid_NHCWB16 = True
op.attrs["split_axis_4D"] = axis_4D_list
return op
@@ -662,7 +743,6 @@ def optimise_strided_conv(op, arch, nng):
ifm_shape = op.ifm_shapes[0]
# IFM
op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
- op.ifm.avoid_NHCWB16 = True
# Weights
weight_shape = weight_tensor.shape
@@ -1129,16 +1209,12 @@ def remove_reshapes(op, arch):
for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
if cons_ifm == ifm:
ifm_cons.set_input_tensor(ofm, ifm_idx)
- if op.ifm_shapes[0] != op.ofm_shapes[0]:
- ofm.avoid_NHCWB16 = True
else:
# Bypassed Reshape by replacing ofm with ifm
for cons in ofm.consumer_list:
for ifm_idx, cons_ifm in enumerate(cons.inputs):
if cons_ifm == ofm:
cons.set_input_tensor(ifm, ifm_idx)
- if op.ifm_shapes[0] != op.ofm_shapes[0]:
- ifm.avoid_NHCWB16 = True
def check_reshapes(op, arch):
@@ -1339,7 +1415,7 @@ def convert_pad(op: Operation, arch, nng):
create_avg_pool_for_concat(
op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
)
- ofm.avoid_NHCWB16 = True
+
op.type = Op.ConcatTFLite
return avgpool_op
@@ -1531,7 +1607,6 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
if h > 64:
shape = [shape[0], 1, h * w, shape[3]]
op.ifm_shapes[0] = Shape4D(shape)
- inp.avoid_NHCWB16 = True
if h > 256 and op.type == Op.AvgPool:
op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
@@ -1688,6 +1763,11 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
sg.refresh_after_modification()
+ # Check Tensor Format restrictions
+ for sg in nng.subgraphs:
+ rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [check_format_restrictions], [])
+ sg.refresh_after_modification()
+
# Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
for sg in nng.subgraphs:
rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])
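
Note the ordering in this final hunk: the format decision is a separate pass appended after the rewrite passes, so check_format_restrictions sees final read/write offsets and ifm/ofm 4D shapes. It is also registered as a tensor visitor rather than an op visitor, since the decision is made per tensor:

# Condensed from the hunk above: tensor functions go in the third argument,
# op functions such as remove_SplitSliceRead in the fourth.
rewrite_graph.visit_graph_post_order(
    sg.output_tensors, arch, [check_format_restrictions], []
)
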