diff options
Diffstat (limited to 'ethosu/vela/graph_optimiser.py')
-rw-r--r-- | ethosu/vela/graph_optimiser.py | 69 |
1 files changed, 53 insertions, 16 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index 4806001f..fdb0fae0 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -75,7 +75,7 @@ def rewrite_concat(tens, arch, nng): new_op = Operation(Op.ConcatSliceWrite, concat_op.name + str(idx)) new_op.inputs = [inp] new_op.outputs = [tens] - new_op.attrs["concat_axis"] = axis + new_op.attrs["concat_axis"] = axis + (4 - len(inp.shape)) new_op.attrs["concat_start"] = offset offset += inp.shape[axis] new_op.attrs["concat_end"] = offset @@ -116,21 +116,20 @@ def rewrite_split(tens, arch, nng): # be calculated from the index of the output tensor if axis is not None: # Get the start and end of the split - offset_start = [0] * len(tens.shape) - offset_end = [0] * len(tens.shape) - for out in outputs: + offset_start = [0] * 4 + for idx, out in enumerate(outputs): if out == tens: break - offset_start[axis] += out.shape[axis] + axis_4D = axis + (4 - len(out.shape)) + offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D] # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input if (offset_start[-1] % 16) != 0: inp.avoid_NHCWB16 = True - - offset_end[axis] = offset_start[axis] + tens.shape[axis] + else: + offset_start = full_shape(4, offset_start, 0) new_op.attrs["split_start"] = offset_start - new_op.attrs["split_end"] = offset_end new_op.run_on_npu = True new_op.set_output_tensor(tens) DebugDatabase.add_optimised(split_op, new_op) @@ -217,6 +216,8 @@ def convert_resizebilinear_1x1_to_add(op): # Set the add inputs op.inputs[1] = op.inputs[0] op.inputs[0] = tens + op.ifm_shapes = [] + op.ofm_shapes = [] return op @@ -321,13 +322,16 @@ def convert_batched_fc_shape(op, arch, nng): ifm = op.inputs[0] ofm = op.outputs[0] # Check if the FC is 2D and first dimension indicates batching - if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1: + # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete + if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1: n = ifm.shape[0] batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)} h, w = batching_split.get(n, (1, n)) prev_op = ifm.ops[0] desired_shape = [1, h, w, ifm.shape[-1]] + op.ifm_shapes[0] = desired_shape + if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape: # There is a preceding Reshape # Compare input of prev_op and input of op, to see if prev_op can be removed @@ -352,6 +356,8 @@ def convert_batched_fc_shape(op, arch, nng): weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) desired_shape = [1, h, w, ofm.shape[-1]] + op.ofm_shapes[0] = desired_shape + if ( len(ofm.consumer_list) == 1 and ofm.consumer_list[0] is not None @@ -451,6 +457,7 @@ def fixup_stridedslice_output(tens, arch, nng): new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape) for idx, out_tens in enumerate(op.outputs): + op.ofm_shapes[idx] = new_shape_tens reshape_in = out_tens.clone("_reshaped") reshape_in.set_all_shapes(reshape_input_shape) reshape_in.ops = [op] @@ -489,7 +496,6 @@ def fixup_unpack_output(tens, arch, nng): DebugDatabase.add_optimised(op, reshape_op) op.outputs[idx] = reshape_in - return tens @@ -582,7 +588,7 @@ def convert_conv_to_fc(op, arch, nng): # caching/double buffering for the weights. # (Weights dont need to be reloaded for convs when IFM H and W are 1) if op.type == Op.Conv2DBias: - _, h, w, _ = op.inputs[0].shape + _, h, w, _ = op.ifm_shapes[0] kh, kw, _, _ = op.inputs[1].shape if h == 1 and w == 1 and kh == 1 and kw == 1: # Overwrite this op as a Fully Connected Op @@ -595,6 +601,7 @@ def convert_conv_to_fc(op, arch, nng): weight_tensor = op.inputs[1] weight_tensor.quant_values = weight_tensor.quant_values.squeeze(axis=(0, 1)) weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) + # The output from a fully connected is expected to be 2D so we need to add a reshape layer to convert it # back to 4D afterwards as the next layer is expecting that shape orig_ofm_tensor = op.outputs[0] @@ -609,6 +616,7 @@ def convert_conv_to_fc(op, arch, nng): reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape reshape_op.inputs = [fc_ofm_tensor, new_shape_tens] reshape_op.set_output_tensor(orig_ofm_tensor) + # Replace this ops OFM to point to the 2D tensor op.outputs[0] = fc_ofm_tensor # Record optimisation in debug database @@ -651,6 +659,8 @@ def fixup_act_reorder(op, arch, nng): prep_op = get_prepend_op(op) if prep_op is not None: act_op = op.clone("_reordered") + act_op.ifm_shapes = list(op.ifm_shapes) + act_op.ofm_shapes = list(op.ofm_shapes) # There is only one input tensor, overwrite it act_op.set_input_tensor(prep_op.inputs[0], 0) @@ -658,6 +668,8 @@ def fixup_act_reorder(op, arch, nng): act_op_out = act_op.inputs[0].clone("_acted") act_op_out.quantization = op.outputs[0].quantization.clone() act_op.set_output_tensor(act_op_out) + act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1) + act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1) # Update the consumer list act_op_out.consumer_list = op.outputs[0].consumer_list.copy() @@ -704,6 +716,15 @@ def set_tensor_equivalence(op, arch, nng): return op +def set_ifm_ofm_op_shapes(op, arch, nng): + if op.run_on_npu and op.type.needs_shapes(): + if op.ifm_shapes or op.ofm_shapes: + # Shapes already set + return op + op.set_ifm_ofm_shapes() + return op + + def convert_softmax(op, arch, nng): if op.type == Op.Softmax and op.run_on_npu: softmax = SoftMax(op) @@ -839,7 +860,7 @@ def convert_lrelu_to_mul_max(op, arch): mul_identity.add_input_tensor(identity_tens) fm_id = ofm.clone(op.name + "_id") mul_identity.set_output_tensor(fm_id) - DebugDatabase.add_optimised(op, mul_alpha) + DebugDatabase.add_optimised(op, mul_identity) # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs op.type = Op.Maximum @@ -869,6 +890,8 @@ def convert_to_lut(op, lut_values, lut_name): quantization.zero_point = 0 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization) op.add_input_tensor(tens) + op.ifm_shapes.append(full_shape(4, tens.shape, 1)) + # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale), # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions # should be the same as the IFM @@ -1072,10 +1095,20 @@ def optimise_graph_a(nng, arch, verbose_graph=False): if verbose_graph: nng.print_graph() + pre_process_list = [ + supported_operator_check, + set_ifm_ofm_op_shapes, + # TODO: memory-only Op removal + ] + + for idx, sg in enumerate(nng.subgraphs): + # rewrite graph pass + nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( + nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, + ) + op_rewrite_list = [ set_tensor_equivalence, - supported_operator_check, - # then do any rewrites of supported operators convert_depthwise_to_conv, convert_conv_to_fc, convert_softmax, @@ -1106,7 +1139,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # remove passthrough tensors and attempt further optimizations nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields] + nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields], ) # Post-optimisation operator debug tracing @@ -1125,7 +1158,11 @@ def optimise_graph_b(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # combined rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [] + nng, + sg, + arch, + [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], + [set_ifm_ofm_op_shapes], ) if verbose_graph: |