Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--	ethosu/vela/tflite_graph_optimiser.py	126
1 file changed, 125 insertions, 1 deletion
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 31d3ae1a..c7fe6cd9 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2423,6 +2423,130 @@ def fixup_reshape(op, arch, nng):
     return op
 
 
+def convert_conv_groups(op: Operation, arch, nng):
+    """
+    Convert convolution groups to a split followed by separate convolutions and then a concat.
+    This needs to run before the concat and split handling functions."""
+    if not op.type.is_conv2d_op():
+        return op
+
+    num_conv_groups = op.attrs.get("num_conv_groups", 0)
+    if num_conv_groups > 1:
+        # convolution groups params
+        ifm_depth_cg = op.ifm.shape[-1] // num_conv_groups
+        num_filters_cg = op.weights.shape[-1] // num_conv_groups
+
+        # create split
+        split_op = Operation(Op.Split, f"{op.name}_split")
+        split_op.attrs.update(
+            {
+                "num_splits": num_conv_groups,
+            }
+        )
+        # first input is the split axis
+        split_op.add_input_tensor(
+            # split along the depth axis
+            create_const_tensor(f"{split_op.name}_axis", [0], DataType.int32, [-1])
+        )
+        # second input is the ifm
+        split_op.add_input_tensor(op.ifm)
+        # calculate shape of each ofm part
+        split_op_ofm_shape = op.ifm.shape[:-1] + [ifm_depth_cg]
+
+        # create concat. do this prior to each conv group so that the for-loop can reference the concat as it iterates
+        concat_op = Operation(Op.ConcatTFLite, f"{op.name}_concat")
+        concat_op.attrs.update(
+            {
+                "axis": -1,
+                "fused_activation_function": None,
+            }
+        )
+        # calculate shape of each ifm part
+        concat_op_ifm_shape = op.ofm.shape[:-1] + [num_filters_cg]
+        # output is the concatenated tensor
+        concat_op.set_output_tensor(op.ofm)  # will disconnect ofm from op
+
+        # for each conv group
+        for i in range(num_conv_groups):
+            # cg params
+            cg_oc_start = i * num_filters_cg
+            cg_oc_end = (i + 1) * num_filters_cg
+
+            # split has multiple outputs
+            split_op_ofm_part = Tensor(split_op_ofm_shape, op.ifm.dtype, f"{split_op.name}_out{i}")
+            split_op_ofm_part.quantization = op.ifm.quantization.clone()
+            split_op.add_output_tensor(split_op_ofm_part)
+
+            # concat has multiple inputs
+            concat_op_ifm_part = Tensor(concat_op_ifm_shape, op.ifm.dtype, f"{concat_op.name}_in{i}")
+            concat_op_ifm_part.quantization = op.ofm.quantization.clone()
+            concat_op.add_input_tensor(concat_op_ifm_part)
+
+            # create convolution group operator
+            conv_group_op = Operation(op.type, f"{op.name}_cg{i}")
+            conv_group_op.attrs = op.attrs.copy()
+            conv_group_op.attrs["num_conv_groups"] = 1
+            # first input is the ifm
+            conv_group_op.add_input_tensor(split_op_ofm_part)
+            # second input is weights. the number of filters (i.e. the output channels) needs to be split equally
+            # across all of the convolution groups
+            conv_group_op_weights_shape = op.weights.shape[:-1] + [num_filters_cg]
+            conv_group_op_weights_quant = op.weights.quantization.clone()
+            conv_group_op_weights_quant.scale_f32 = op.weights.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
+            conv_group_op_weights_quant.zero_point = op.weights.quantization.zero_point[..., cg_oc_start:cg_oc_end]
+            conv_group_op.add_input_tensor(
+                create_const_tensor(
+                    f"{op.weights.name}_cg{i}",
+                    conv_group_op_weights_shape,
+                    op.weights.dtype,
+                    op.weights.values[..., cg_oc_start:cg_oc_end],
+                    op.weights.purpose,
+                    conv_group_op_weights_quant,
+                )
+            )
+            # third input is bias. like the weights, the bias needs to be split equally across all of the convolution
+            # groups
+            if op.bias is None:
+                conv_group_op.add_input_tensor(None)
+            else:
+                conv_group_op_bias_shape = op.bias.shape[:-1] + [num_filters_cg]
+                conv_group_op_bias_quant = op.bias.quantization.clone()
+                conv_group_op_bias_quant.scale_f32 = op.bias.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
+                conv_group_op_bias_quant.zero_point = op.bias.quantization.zero_point[..., cg_oc_start:cg_oc_end]
+                conv_group_op.add_input_tensor(
+                    create_const_tensor(
+                        f"{op.bias.name}_cg{i}",
+                        conv_group_op_bias_shape,
+                        op.bias.dtype,
+                        op.bias.values[..., cg_oc_start:cg_oc_end],
+                        op.bias.purpose,
+                        conv_group_op_bias_quant,
+                    )
+                )
+            # output goes to the concat
+            conv_group_op.set_output_tensor(concat_op_ifm_part)
+            # update the cg op shapes and debug db
+            conv_group_op.set_ifm_ofm_shapes()
+            DebugDatabase.add_optimised(op, conv_group_op)
+
+        # update the split/concat op shapes/debug db
+        split_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, split_op)
+        concat_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, concat_op)
+
+        # disconnect the original convolution operator.
+        # the ofm has already been disconnected by concat_op.set_output_tensor()
+        op.ifm.consumer_list.remove(op)
+        op.inputs = []
+        op.outputs = []
+
+        # return last op so that other graph optimiser functions can process the new operators
+        op = concat_op
+
+    return op
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
     return op
@@ -2447,7 +2571,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
     )
 
     # Pre-processing step
-    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape]
+    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape, convert_conv_groups]
 
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
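For context: the rewrite above is behaviour-preserving because a grouped convolution is, by definition, equivalent to splitting the ifm along depth, convolving each slice with its own slice of the filters, and concatenating the per-group results along depth, which is exactly the Split/Conv2D/ConcatTFLite subgraph the pass builds. The following standalone sketch illustrates that equivalence; it is not part of the patch, it uses a 1x1 convolution (a numpy matmul) so the spatial dimensions drop out, and all names in it are hypothetical.

import numpy as np

def conv1x1(ifm, weights):
    # ifm: (H, W, Cin) @ weights: (Cin, Cout) -> ofm: (H, W, Cout)
    return ifm @ weights

def grouped_conv1x1(ifm, weights, groups):
    # weights: (Cin // groups, Cout); group i owns output channels
    # [i * Cout // groups, (i + 1) * Cout // groups), mirroring the
    # cg_oc_start:cg_oc_end slicing in convert_conv_groups
    cin_g = ifm.shape[-1] // groups
    cout_g = weights.shape[-1] // groups
    parts = []
    for i in range(groups):
        ifm_part = ifm[..., i * cin_g : (i + 1) * cin_g]      # the Split
        w_part = weights[..., i * cout_g : (i + 1) * cout_g]  # per-group weights
        parts.append(conv1x1(ifm_part, w_part))               # per-group conv
    return np.concatenate(parts, axis=-1)                     # the Concat

rng = np.random.default_rng(0)
groups, cin, cout = 2, 8, 16
ifm = rng.random((4, 4, cin))
weights = rng.random((cin // groups, cout))

# The grouped result equals a full 1x1 conv whose (Cin, Cout) weight matrix is
# block diagonal, with one (Cin // groups, Cout // groups) block per group.
full = np.zeros((cin, cout))
for i in range(groups):
    rows = slice(i * cin // groups, (i + 1) * cin // groups)
    cols = slice(i * cout // groups, (i + 1) * cout // groups)
    full[rows, cols] = weights[:, cols]
assert np.allclose(grouped_conv1x1(ifm, weights, groups), conv1x1(ifm, full))

As in the pass itself, the weights' input-channel dimension is already Cin // groups in the grouped layout, so only the output channels (and, in the quantised case, their per-channel scales and zero points) need slicing per group.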