From cb33704fcd7859b1c334f996445bba2f4efea5f9 Mon Sep 17 00:00:00 2001
From: Patrik Gustavsson
Date: Wed, 16 Sep 2020 14:55:40 +0200
Subject: MLBEDSW-1693 Convert batched FC to Conv

Added support to convert batched FC to conv.
This enables choosing a suitable block-size.

Signed-off-by: Patrik Gustavsson
Change-Id: Idc49e4fb6d29c554f10a38ece7996a7b7795ffad
---
 ethosu/vela/graph_optimiser.py | 84 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 82 insertions(+), 2 deletions(-)

diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index a57ac82e..f6b03f67 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -329,11 +329,82 @@ def fixup_fully_connected_input(op, arch):
         desired_shape = [batch_size, n_in_elems]
         if inp.shape != desired_shape:
             # mismatch, insert a reshape to fix this.
-            op.inputs[0] = create_reshape_tensor(inp, desired_shape)
+            op.set_input_tensor(create_reshape_tensor(inp, desired_shape), 0)
 
     return op
 
 
+def convert_batched_fc_to_conv(op, arch):
+    if op.type == "FullyConnectedAct":
+        ifm = op.inputs[0]
+        ofm = op.outputs[0]
+        # Check if the FC is 2D and first dimension indicates batching
+        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] != 1:
+            n = ifm.shape[0]
+            batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
+            h, w = batching_split.get(n, (1, n))
+
+            # Convert to convolution
+            op.name += "_conv"
+            op.type = "Conv2DBiasAct"
+            faf = op.attrs.get("fused_activation_function", None)
+            op.attrs = {
+                "dilation": (1, 1, 1, 1),
+                "dilation_h_factor": 1,
+                "dilation_w_factor": 1,
+                "fused_activation_function": faf,
+                "npu_block_type": NpuBlockType.ConvolutionMxN,
+                "padding": b"SAME",
+                "stride_h": 1,
+                "stride_w": 1,
+                "strides": (1, 1, 1, 1),
+            }
+
+            prev_op = ifm.ops[0]
+            desired_shape = [1, h, w, ifm.shape[-1]]
+            if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == "Reshape":
+                # There is a preceding Reshape
+                # Compare input of prev_op and input of op, to see if prev_op can be removed
+                ifm_prev_op = prev_op.inputs[0]
+                if ifm_prev_op.shape == ifm.shape and ifm_prev_op.quantization.is_scaling_equal(ifm.quantization):
+                    # prev_op can be removed
+                    op.set_input_tensor(ifm_prev_op, 0)
+                else:
+                    op.inputs[0].set_all_shapes(desired_shape)
+                    prev_op.set_input_tensor(
+                        create_const_tensor(prev_op.inputs[1].name, [1], DataType.int32, desired_shape), 1
+                    )
+                    prev_op.attrs["new_shape"] = desired_shape
+            else:
+                # Add reshape op to the input if there is no preceding reshape
+                ifm.consumer_list.remove(op)
+                op.set_input_tensor(create_reshape_tensor(ifm, desired_shape), 0)
+
+            # Reshape Weights to be 4D. IO becomes HWIO
+            weight_tensor = op.inputs[1]
+            weight_tensor.quant_values = np.expand_dims(np.expand_dims(weight_tensor.quant_values, axis=0), axis=0)
+            weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
+
+            desired_shape = [1, h, w, ofm.shape[-1]]
+            if (
+                len(ofm.consumer_list) == 1
+                and ofm.consumer_list[0] is not None
+                and ofm.consumer_list[0].type == "Reshape"
+            ):
+                # There is a subsequent Reshape
+                # Compare desired shape and output of consumer op, to see if consumer op can be removed
+                ofm_cons_op = ofm.consumer_list[0].outputs[0]
+                if desired_shape == ofm_cons_op.shape and ofm.quantization.is_scaling_equal(ofm_cons_op.quantization):
+                    op.outputs[0] = ofm_cons_op
+                    op.outputs[0].ops = [op]
+                else:
+                    op.outputs[0].set_all_shapes(desired_shape)
+            else:
+                # Add reshape op to the output
+                op.set_output_tensor(create_reshape_tensor(ofm, desired_shape, False))
+    return op
+
+
 def fixup_pack_input(op, arch):
     if op.type == "Pack":
         # Pack is also referred to as Stack
@@ -598,10 +669,18 @@ def fixup_act_reorder(op, arch):
         prep_op = get_prepend_op(op)
         if prep_op is not None:
             act_op = op.clone("_reordered")
-            act_op.inputs = [prep_op.inputs[0]]
+
+            # There is only one input tensor, overwrite it
+            act_op.set_input_tensor(prep_op.inputs[0], 0)
+
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
             act_op.set_output_tensor(act_op_out)
+
+            # Update the consumer list
+            act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
+            act_op_out.consumer_list.append(prep_op)
+
             prep_op.inputs[0] = act_op_out
             prep_op.outputs[0].quantization = act_op_out.quantization.clone()
@@ -956,6 +1035,7 @@ def optimise_graph_a(nng, arch, verbose_graph=False):
         convert_conv_to_fc,
         convert_softmax,
         fixup_fully_connected_input,
+        convert_batched_fc_to_conv,
         fixup_pack_input,
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
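
As an illustration of the batching split above, here is a minimal
standalone sketch (the helper name fc_batch_to_hw is hypothetical; the
patch inlines this logic). A batch of n FC rows is laid out as an
(h, w) spatial plane, so the block-size choice is no longer pinned to
a height of 1:

    # Sketch of the batching split; fc_batch_to_hw is a hypothetical name.
    def fc_batch_to_hw(n):
        batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
        # Batch sizes without a dedicated split fall back to a 1 x n plane.
        return batching_split.get(n, (1, n))

    for n in (2, 4, 8, 16, 32):
        h, w = fc_batch_to_hw(n)
        assert h * w == n  # the plane holds exactly the n batch rows
        print("batch %2d -> IFM shape [1, %d, %d, C]" % (n, h, w))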
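
The weight handling relies on the reshape being value-preserving: FC
weights stored as [I, O] become a 1x1 Conv2D kernel of shape
[1, 1, I, O] (HWIO with H = W = 1). A small numpy check of that step,
assuming nothing beyond numpy itself:

    import numpy as np

    fc_weights = np.arange(12, dtype=np.int8).reshape(3, 4)  # [I=3, O=4]

    # Same expand_dims pair as in the patch: prepend the H and W axes.
    conv_weights = np.expand_dims(np.expand_dims(fc_weights, axis=0), axis=0)

    assert conv_weights.shape == (1, 1, 3, 4)              # [H, W, I, O]
    assert np.array_equal(conv_weights[0, 0], fc_weights)  # data untouched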
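
The conversion as a whole preserves values too: a batched FC over
[n, I] equals a 1x1 convolution over the reshaped [1, h, w, I] plane,
followed by a reshape back to [n, O]. A numpy sketch of that
equivalence, in floating point and ignoring the quantization
bookkeeping the patch takes care of:

    import numpy as np

    n, i, o = 8, 16, 32
    h, w = 2, 4                                   # batching_split[8]
    x = np.random.randn(n, i).astype(np.float32)
    wts = np.random.randn(i, o).astype(np.float32)

    fc_out = x @ wts                              # FullyConnected: [n, O]

    plane = x.reshape(1, h, w, i)                 # the inserted input Reshape
    conv_out = np.einsum("bhwi,io->bhwo", plane, wts)  # 1x1 conv, HWIO kernel
    back = conv_out.reshape(n, o)                 # the inserted output Reshape

    assert np.allclose(fc_out, back, atol=1e-5)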