From bf31d647dc5df47410ee577b12427ddf076d816b Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Wed, 16 Dec 2020 13:08:06 +0100 Subject: MLBEDSW-3645 4D class for op ifm/ofm shapes Add 4D shape class for op Ifm/ofm shapes Signed-off-by: Patrik Gustavsson Change-Id: Ic0a98da9d2f9d085605e39a9ab5a26bad6e702a3 --- ethosu/vela/graph_optimiser.py | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) (limited to 'ethosu/vela/graph_optimiser.py') diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index fdb0fae0..1128a311 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -37,6 +37,7 @@ from .operation import Op from .operation import Operation from .operation import Padding from .operation_util import create_avgpool_nop +from .shape4d import Shape4D from .softmax import SoftMax from .tensor import check_quantized_tens_scaling_equal from .tensor import create_const_tensor @@ -82,6 +83,7 @@ def rewrite_concat(tens, arch, nng): new_op.run_on_npu = True tens.ops.append(new_op) DebugDatabase.add_optimised(concat_op, new_op) + new_op.set_ifm_ofm_shapes() assert tens.shape[axis] == offset # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a @@ -121,7 +123,8 @@ def rewrite_split(tens, arch, nng): if out == tens: break axis_4D = axis + (4 - len(out.shape)) - offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D] + + offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D) # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input if (offset_start[-1] % 16) != 0: @@ -132,6 +135,7 @@ def rewrite_split(tens, arch, nng): new_op.attrs["split_start"] = offset_start new_op.run_on_npu = True new_op.set_output_tensor(tens) + new_op.set_ifm_ofm_shapes() DebugDatabase.add_optimised(split_op, new_op) return tens @@ -189,6 +193,7 @@ def fixup_conv2d_backprop(op, arch, nng): if op.type == Op.Conv2DBackpropInput: # flip the inputs op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0] + op.set_ifm_ofm_shapes() op.type = Op.Conv2DBackpropInputSwitchedBias # Update strides @@ -216,8 +221,7 @@ def convert_resizebilinear_1x1_to_add(op): # Set the add inputs op.inputs[1] = op.inputs[0] op.inputs[0] = tens - op.ifm_shapes = [] - op.ofm_shapes = [] + op.set_ifm_ofm_shapes() return op @@ -323,14 +327,14 @@ def convert_batched_fc_shape(op, arch, nng): ofm = op.outputs[0] # Check if the FC is 2D and first dimension indicates batching # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete - if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1: + if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1: n = ifm.shape[0] batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)} h, w = batching_split.get(n, (1, n)) prev_op = ifm.ops[0] desired_shape = [1, h, w, ifm.shape[-1]] - op.ifm_shapes[0] = desired_shape + op.ifm_shapes[0] = Shape4D(desired_shape) if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape: # There is a preceding Reshape @@ -356,7 +360,7 @@ def convert_batched_fc_shape(op, arch, nng): weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape)) desired_shape = [1, h, w, ofm.shape[-1]] - op.ofm_shapes[0] = desired_shape + op.ofm_shapes[0] = Shape4D(desired_shape) if ( len(ofm.consumer_list) == 1 @@ -395,6 +399,7 @@ def fixup_pack_input(op, arch, nng): reshape_op.attrs["new_shape"] = desired_shape reshape_op.inputs = [inp, new_shape_tens] reshape_op.set_output_tensor(reshape_out) + reshape_op.set_ifm_ofm_shapes() DebugDatabase.add_optimised(op, reshape_op) op.inputs[idx] = reshape_out @@ -413,6 +418,7 @@ def unfuse_activation_function(op, arch, nng): act_op.set_output_tensor(out_tens) act_op.add_input_tensor(intermediate_tens) op.set_output_tensor(intermediate_tens) + act_op.set_ifm_ofm_shapes() return op @@ -457,7 +463,7 @@ def fixup_stridedslice_output(tens, arch, nng): new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape) for idx, out_tens in enumerate(op.outputs): - op.ofm_shapes[idx] = new_shape_tens + op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape) reshape_in = out_tens.clone("_reshaped") reshape_in.set_all_shapes(reshape_input_shape) reshape_in.ops = [op] @@ -466,6 +472,7 @@ def fixup_stridedslice_output(tens, arch, nng): reshape_op.attrs["new_shape"] = reshape_input_shape reshape_op.inputs = [reshape_in, new_shape_tens] reshape_op.set_output_tensor(out_tens) + reshape_op.set_ifm_ofm_shapes() op.outputs[idx] = reshape_in @@ -493,6 +500,7 @@ def fixup_unpack_output(tens, arch, nng): reshape_op.attrs["new_shape"] = reshape_input_shape reshape_op.inputs = [reshape_in, new_shape_tens] reshape_op.set_output_tensor(out_tens) + reshape_op.set_ifm_ofm_shapes() DebugDatabase.add_optimised(op, reshape_op) op.outputs[idx] = reshape_in @@ -588,7 +596,8 @@ def convert_conv_to_fc(op, arch, nng): # caching/double buffering for the weights. # (Weights dont need to be reloaded for convs when IFM H and W are 1) if op.type == Op.Conv2DBias: - _, h, w, _ = op.ifm_shapes[0] + h = op.ifm_shapes[0].height + w = op.ifm_shapes[0].width kh, kw, _, _ = op.inputs[1].shape if h == 1 and w == 1 and kh == 1 and kw == 1: # Overwrite this op as a Fully Connected Op @@ -616,9 +625,11 @@ def convert_conv_to_fc(op, arch, nng): reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape reshape_op.inputs = [fc_ofm_tensor, new_shape_tens] reshape_op.set_output_tensor(orig_ofm_tensor) + reshape_op.set_ifm_ofm_shapes() # Replace this ops OFM to point to the 2D tensor op.outputs[0] = fc_ofm_tensor + op.set_ifm_ofm_shapes() # Record optimisation in debug database DebugDatabase.add_optimised(op, reshape_op) DebugDatabase.add_optimised(op, op) @@ -649,6 +660,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng): relu_fused_op.add_input_tensor(ifm) relu_fused_op.set_output_tensor(ofm) + relu_fused_op.set_ifm_ofm_shapes() op = relu_fused_op return op @@ -668,8 +680,8 @@ def fixup_act_reorder(op, arch, nng): act_op_out = act_op.inputs[0].clone("_acted") act_op_out.quantization = op.outputs[0].quantization.clone() act_op.set_output_tensor(act_op_out) - act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1) - act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1) + act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape) + act_op.ofm_shapes[0] = Shape4D(act_op_out.shape) # Update the consumer list act_op_out.consumer_list = op.outputs[0].consumer_list.copy() @@ -839,6 +851,7 @@ def convert_lrelu_to_mul_max(op, arch): mul_alpha.add_input_tensor(alpha_tens) fm_alpha = ofm.clone(op.name + "_alpha") mul_alpha.set_output_tensor(fm_alpha) + mul_alpha.set_ifm_ofm_shapes() DebugDatabase.add_optimised(op, mul_alpha) if check_quantized_tens_scaling_equal(ifm, ofm): @@ -860,6 +873,7 @@ def convert_lrelu_to_mul_max(op, arch): mul_identity.add_input_tensor(identity_tens) fm_id = ofm.clone(op.name + "_id") mul_identity.set_output_tensor(fm_id) + mul_identity.set_ifm_ofm_shapes() DebugDatabase.add_optimised(op, mul_identity) # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs @@ -890,7 +904,7 @@ def convert_to_lut(op, lut_values, lut_name): quantization.zero_point = 0 tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization) op.add_input_tensor(tens) - op.ifm_shapes.append(full_shape(4, tens.shape, 1)) + op.ifm_shapes.append(Shape4D(tens.shape)) # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale), # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions @@ -1158,11 +1172,7 @@ def optimise_graph_b(nng, arch, verbose_graph=False): for idx, sg in enumerate(nng.subgraphs): # combined rewrite graph pass nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, - sg, - arch, - [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], - [set_ifm_ofm_op_shapes], + nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [], ) if verbose_graph: -- cgit v1.2.1