From ee99bb124b088430b97d205df9fc90a1e9412e0c Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Thu, 8 Apr 2021 09:04:00 +0200
Subject: MLBEDSW-4334 Non-linear format decision in graph opt.

The check for whether the non-linear tensor format can be used is
refactored:
-Flag avoid_NHCWB16 replaced with needs_linear_format
-Checking of restrictions moved to one function in the graph optimiser

Signed-off-by: Patrik Gustavsson
Change-Id: Iec5c7996a1a6039cad052197f1ae56f7c0290440
---
 ethosu/vela/graph_optimiser.py | 138 ++++++++++++++++++++++++++++++++---------
 1 file changed, 109 insertions(+), 29 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 56932dbe..dd540a79 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -104,8 +104,6 @@ def rewrite_concat_ops(op, arch):
 
         for idx, inp in enumerate(op.inputs):
             op.ifm_shapes[idx] = Shape4D(desired_shape)
-            if Shape4D(inp.shape) != op.ifm_shapes[idx]:
-                inp.avoid_NHCWB16 = True
         op.type = Op.PackReshaped
 
     inputs, axis = op.get_concat_inputs_axis()
@@ -125,12 +123,7 @@ def rewrite_concat_ops(op, arch):
         offset = concat_end
     assert ofm.shape[axis] == offset
 
-    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
-    # multiple of 16. This as, it is only then the address offset for the ofm, for all operations, will be 16 byte
-    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
-    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
-    if axis == -1 or axis == (len(ofm.shape) - 1):
-        ofm.avoid_NHCWB16 = any(op2.write_offset.depth % 16 != 0 for op2 in ofm.ops if op2.write_offset is not None)
+    return op
 
 
 def rewrite_split_ops(tens, arch, nng):
@@ -171,10 +164,6 @@ def rewrite_split_ops(tens, arch, nng):
 
             offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
 
-            # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
-            if (offset_start[-1] % 16) != 0:
-                inp.avoid_NHCWB16 = True
-
         new_op.read_offsets[0] = Shape4D.from_list(offset_start, 0)
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
@@ -224,6 +213,108 @@ def remove_SplitSliceRead(op, arch):
             DebugDatabase.add_optimised(op, avgpool_op)
 
 
+def avoid_nhcwb16_for_concat(tens):
+    # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
+    # multiple of 16. This is because only then will the address offset for the ofm, for all operations, be 16 byte
+    # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
+    # and those addresses are always 16 byte aligned due to the NHCWB16 format.
+    return any(op.write_offset.depth % 16 != 0 for op in tens.ops if op.write_offset is not None)
+
+
+def avoid_nhcwb16_for_split(tens):
+    # If the read offset is not a multiple of 16 in the C-dimension, NHCWB16 needs to be avoided in the input
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            read_offset = cons_op.read_offsets[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            read_offset = cons_op.read_offsets[1]
+        else:
+            assert False
+        if read_offset is not None and (read_offset[-1] % 16) != 0:
+            return True
+    return False
+
+
+def avoid_nhcwb16_for_shapes(tens):
+    # Check all producers/consumers to see if any op shape is preventing NHCWB16
+    for cons_op in tens.consumer_list:
+        if cons_op.ifm == tens:
+            cons_op_shape = cons_op.ifm_shapes[0]
+        elif cons_op.type.is_binary_elementwise_op() and cons_op.ifm2 == tens:
+            cons_op_shape = cons_op.ifm_shapes[1]
+        else:
+            assert False
+        if Shape4D(tens.shape) != cons_op_shape:
+            return True
+
+    for prod_op in tens.ops:
+        if Shape4D(tens.shape) != prod_op.ofm_shapes[0]:
+            return True
+
+    return False
+
+
+# Check if non-linear format can be used
+def check_format_restrictions(tens, arch):
+    if len(tens.ops) < 1:
+        return
+    if tens.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const) or any(
+        cons is None for cons in tens.consumer_list
+    ):
+        return
+
+    if not any(cons.run_on_npu for cons in tens.consumer_list):
+        return
+    if not any(prod.run_on_npu for prod in tens.ops):
+        return
+
+    # "Concat" ofm exception:
+    if avoid_nhcwb16_for_concat(tens):
+        return
+
+    # "Split" ifm exception:
+    if avoid_nhcwb16_for_split(tens):
+        return
+
+    # Shapes checking: check that all producers/consumers are NHCWB16 compatible with tens.shape
+    if avoid_nhcwb16_for_shapes(tens):
+        return
+
+    for op in tens.consumer_list:
+        if op.type == Op.ReduceSum and tens.dtype == DataType.int32:
+            return
+        if op.type == Op.Reshape:
+            # Using NHCWB16 format for a no-op reshape is only an option if subsequent
+            # consumers do not also need to perform a reshape or if the OFM is going to
+            # be processed by CPU operations. No-op reshape consumers with empty lists
+            # (those that have no consumers, or null-consumers used as list terminators)
+            # must use normal NHWC output.
+
+            def incompatible_consumers(oper):
+                if oper and oper.type == Op.Reshape:
+                    for consumer in oper.outputs[0].consumer_list:
+                        yield from incompatible_consumers(consumer)
+                yield not oper or not oper.run_on_npu
+
+            if not any(incompatible_consumers(op)):
+
+                def get_rewrites(oper):
+                    if oper and oper.type == Op.Reshape:
+                        for consumer in oper.outputs[0].consumer_list:
+                            yield from get_rewrites(consumer)
+                    yield oper
+
+                # Detect no-op reshapes by comparing their full input and output tensor shapes.
+                inshape = op.ifm_shapes[0]
+                compatible_shape = [(inshape == oper.ofm_shapes[0]) for oper in get_rewrites(op)]
+                if not (compatible_shape and all(compatible_shape)):
+                    return
+            else:
+                return
+
+    tens.needs_linear_format = False
+
+
 def insert_copy_op_after_tens(tens):
     tens_cons_list_copy = tens.consumer_list.copy()
 
@@ -459,8 +550,6 @@ def rewrite_fully_connected_input(op, arch, nng):
         assert batch_size * n_in_elems == elms
 
         op.ifm_shapes[0] = Shape4D([batch_size, 1, 1, n_in_elems])
-        if Shape4D(op.ifm.shape) != op.ifm_shapes[0]:
-            op.ifm.avoid_NHCWB16 = True
     return op
 
 
@@ -473,8 +562,6 @@ def convert_batched_fc_shape(op, arch, nng):
 
         h, w = batching_split.get(n, (1, n))
         op.ifm_shapes[0] = Shape4D([1, h, w, op.ifm_shapes[0].depth])
-        op.ifm.avoid_NHCWB16 = True
-
         # Reshape Weights to be 4D. IO becomes HWIO
         weight_tensor = op.inputs[1]
         weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
@@ -483,7 +570,6 @@ def convert_batched_fc_shape(op, arch, nng):
         n = op.ofm_shapes[0].batch
         h, w = batching_split.get(n, (1, n))
         op.ofm_shapes[0] = Shape4D([1, h, w, op.ofm_shapes[0].depth])
-        op.ofm.avoid_NHCWB16 = True
 
     return op
 
@@ -550,9 +636,6 @@ def rewrite_stridedslice_output(op, arch, nng):
             axis_4D[idx] = axis
         op.ofm_shapes[idx] = Shape4D(output_shape)
 
-        if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
-            out_tens.avoid_NHCWB16 = True
-
     op.attrs["split_axis_4D"] = axis_4D
     return op
 
@@ -574,8 +657,6 @@ def rewrite_unpack_output(op, arch, nng):
     for idx, out_tens in enumerate(op.outputs):
         op.ofm_shapes[idx] = Shape4D(desired_output_shape)
         axis_4D_list[idx] = axis_4D
-        if op.ofm_shapes[idx] != Shape4D(out_tens.shape):
-            out_tens.avoid_NHCWB16 = True
 
     op.attrs["split_axis_4D"] = axis_4D_list
     return op
@@ -662,7 +743,6 @@ def optimise_strided_conv(op, arch, nng):
         ifm_shape = op.ifm_shapes[0]
         # IFM
         op.ifm_shapes[0] = Shape4D([ifm_shape.batch, ifm_shape.height, ifm_shape.width // 2, ifm_shape.depth * 2])
-        op.ifm.avoid_NHCWB16 = True
 
         # Weights
         weight_shape = weight_tensor.shape
@@ -1129,16 +1209,12 @@ def remove_reshapes(op, arch):
             for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
                 if cons_ifm == ifm:
                     ifm_cons.set_input_tensor(ofm, ifm_idx)
-            if op.ifm_shapes[0] != op.ofm_shapes[0]:
-                ofm.avoid_NHCWB16 = True
     else:
         # Bypassed Reshape by replacing ofm with ifm
         for cons in ofm.consumer_list:
             for ifm_idx, cons_ifm in enumerate(cons.inputs):
                 if cons_ifm == ofm:
                     cons.set_input_tensor(ifm, ifm_idx)
-        if op.ifm_shapes[0] != op.ofm_shapes[0]:
-            ifm.avoid_NHCWB16 = True
 
 
 def check_reshapes(op, arch):
@@ -1339,7 +1415,7 @@ def convert_pad(op: Operation, arch, nng):
         create_avg_pool_for_concat(
             op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
         )
-    ofm.avoid_NHCWB16 = True
+
     op.type = Op.ConcatTFLite
     return avgpool_op
 
@@ -1531,7 +1607,6 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
         if h > 64:
             shape = [shape[0], 1, h * w, shape[3]]
             op.ifm_shapes[0] = Shape4D(shape)
-            inp.avoid_NHCWB16 = True
             if h > 256 and op.type == Op.AvgPool:
                 op.attrs.update({"ksize": (1, 1, h * w, 1), "filter_height": 1, "filter_width": h * w})
@@ -1688,6 +1763,11 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_SplitSliceRead])
         sg.refresh_after_modification()
 
+    # Check Tensor Format restrictions
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [check_format_restrictions], [])
+        sg.refresh_after_modification()
+
     # Post-optimisation operator debug tracing, and checking that no undesired reshapes are left in the graph
     for sg in nng.subgraphs:
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [check_reshapes, _record_optimised])
--
cgit v1.2.1
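
The concat rule this patch moves into avoid_nhcwb16_for_concat() can be replayed in isolation: NHCWB16 lays channels out in bricks of 16, so an operation writing one concat input into a shared OFM can only address that output in brick format when its channel (depth) write offset falls on a 16-channel brick boundary. The minimal Python sketch below illustrates just that test; WriteOp and can_use_nhcwb16_for_concat are hypothetical names invented for the illustration and are not part of Vela, and only the depth % 16 condition is taken from the patch.

    # Minimal sketch of the NHCWB16 concat alignment rule (illustration only).
    from typing import List, Optional


    class WriteOp:
        def __init__(self, depth_offset: Optional[int]):
            # Channel offset at which this op writes into the concatenated OFM;
            # None models an op without a write offset, which the check skips.
            self.depth_offset = depth_offset


    def can_use_nhcwb16_for_concat(write_ops: List[WriteOp]) -> bool:
        # True only if every depth write offset is brick-aligned; this is
        # avoid_nhcwb16_for_concat()'s any(... % 16 != 0) test, inverted.
        return all(
            op.depth_offset % 16 == 0 for op in write_ops if op.depth_offset is not None
        )


    # Depths 16, 16, 8 concatenated along C give offsets 0, 16, 32: all aligned.
    print(can_use_nhcwb16_for_concat([WriteOp(0), WriteOp(16), WriteOp(32)]))  # True
    # Depths 8, 8 give offsets 0, 8: offset 8 lands mid-brick, so fall back to NHWC.
    print(can_use_nhcwb16_for_concat([WriteOp(0), WriteOp(8)]))  # False

check_format_restrictions() applies the same pattern tensor by tensor: any failed restriction returns early and leaves needs_linear_format in place, and only tensors that pass every check are cleared for the non-linear brick format.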