Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py  126
1 file changed, 125 insertions(+), 1 deletion(-)
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 31d3ae1a..c7fe6cd9 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -2423,6 +2423,130 @@ def fixup_reshape(op, arch, nng):
     return op
 
 
+def convert_conv_groups(op: Operation, arch, nng):
+    """
+    Convert convolution groups to a split followed by separate convolutions and then a concat.
+    This needs to run before the concat and split handling functions"""
+    if not op.type.is_conv2d_op():
+        return op
+
+    num_conv_groups = op.attrs.get("num_conv_groups", 0)
+    if num_conv_groups > 1:
+        # convolution groups params
+        ifm_depth_cg = op.ifm.shape[-1] // num_conv_groups
+        num_filters_cg = op.weights.shape[-1] // num_conv_groups
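+        # e.g. an ifm depth of 8 and 16 filters with num_conv_groups = 4 gives each group
+        # 2 input channels (ifm_depth_cg) and 4 filters (num_filters_cg)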
+
+        # create split
+        split_op = Operation(Op.Split, f"{op.name}_split")
+        split_op.attrs.update(
+            {
+                "num_splits": num_conv_groups,
+            }
+        )
+        # first input is the split axis
+        split_op.add_input_tensor(
+            # split along the depth axis
+            create_const_tensor(f"{split_op.name}_axis", [0], DataType.int32, [-1])
+        )
+        # second input is the ifm
+        split_op.add_input_tensor(op.ifm)
+        # calculate shape of each ofm part
+        split_op_ofm_shape = op.ifm.shape[:-1] + [ifm_depth_cg]
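+        # each part keeps the ifm's batch and spatial dims but only ifm_depth_cg channels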
+
+        # create concat. do this prior to each conv group so that the for-loop can reference the concat as it iterates
+        concat_op = Operation(Op.ConcatTFLite, f"{op.name}_concat")
+        concat_op.attrs.update(
+            {
+                "axis": -1,
+                "fused_activation_function": None,
+            }
+        )
+        # calculate shape of each ifm part
+        concat_op_ifm_shape = op.ofm.shape[:-1] + [num_filters_cg]
+        # output is the concatenated tensor
+        concat_op.set_output_tensor(op.ofm)  # will disconnect ofm from op
+
+        # for each conv group
+        for i in range(num_conv_groups):
+            # cg params
+            cg_oc_start = i * num_filters_cg
+            cg_oc_end = (i + 1) * num_filters_cg
+
+            # split has multiple outputs
+            split_op_ofm_part = Tensor(split_op_ofm_shape, op.ifm.dtype, f"{split_op.name}_out{i}")
+            split_op_ofm_part.quantization = op.ifm.quantization.clone()
+            split_op.add_output_tensor(split_op_ofm_part)
+
+            # concat has multiple inputs
+            concat_op_ifm_part = Tensor(concat_op_ifm_shape, op.ifm.dtype, f"{concat_op.name}_in{i}")
+            concat_op_ifm_part.quantization = op.ofm.quantization.clone()
+            concat_op.add_input_tensor(concat_op_ifm_part)
+
+            # create convolution group operator
+            conv_group_op = Operation(op.type, f"{op.name}_cg{i}")
+            conv_group_op.attrs = op.attrs.copy()
+            conv_group_op.attrs["num_conv_groups"] = 1
+            # first input is the ifm
+            conv_group_op.add_input_tensor(split_op_ofm_part)
+            # second input is weights. the number of filters (i.e. the output channels) needs to be split equally
+            # across all of the convolution groups
+            conv_group_op_weights_shape = op.weights.shape[:-1] + [num_filters_cg]
+            conv_group_op_weights_quant = op.weights.quantization.clone()
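+            # note: assumes per-channel quantisation, i.e. scale_f32 and zero_point are arrays
+            # indexed along the output channel axis, matching the weight value slice below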
+            conv_group_op_weights_quant.scale_f32 = op.weights.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
+            conv_group_op_weights_quant.zero_point = op.weights.quantization.zero_point[..., cg_oc_start:cg_oc_end]
+            conv_group_op.add_input_tensor(
+                create_const_tensor(
+                    f"{op.weights.name}_cg{i}",
+                    conv_group_op_weights_shape,
+                    op.weights.dtype,
+                    op.weights.values[..., cg_oc_start:cg_oc_end],
+                    op.weights.purpose,
+                    conv_group_op_weights_quant,
+                )
+            )
+            # third input is bias. like the weights, the bias needs to be split equally across all of the
+            # convolution groups
+            if op.bias is None:
+                conv_group_op.add_input_tensor(None)
+            else:
+                conv_group_op_bias_shape = op.bias.shape[:-1] + [num_filters_cg]
+                conv_group_op_bias_quant = op.bias.quantization.clone()
+                conv_group_op_bias_quant.scale_f32 = op.bias.quantization.scale_f32[..., cg_oc_start:cg_oc_end]
+                conv_group_op_bias_quant.zero_point = op.bias.quantization.zero_point[..., cg_oc_start:cg_oc_end]
+                conv_group_op.add_input_tensor(
+                    create_const_tensor(
+                        f"{op.bias.name}_cg{i}",
+                        conv_group_op_bias_shape,
+                        op.bias.dtype,
+                        op.bias.values[..., cg_oc_start:cg_oc_end],
+                        op.bias.purpose,
+                        conv_group_op_bias_quant,
+                    )
+                )
+            # output goes to the concat
+            conv_group_op.set_output_tensor(concat_op_ifm_part)
+            # update the cg op shapes and debug db
+            conv_group_op.set_ifm_ofm_shapes()
+            DebugDatabase.add_optimised(op, conv_group_op)
+
+        # update the split/concat op shapes/debug db
+        split_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, split_op)
+        concat_op.set_ifm_ofm_shapes()
+        DebugDatabase.add_optimised(op, concat_op)
+
+        # disconnect the original convolution operator.
+        # the ofm has already been disconnected by concat_op.set_output_tensor()
+        op.ifm.consumer_list.remove(op)
+        op.inputs = []
+        op.outputs = []
+
+        # return last op so that other graph optimiser functions can process the new operators
+        op = concat_op
+
+    return op
+
+
 def supported_operator_check(op, arch, nng):
     op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
     return op
@@ -2447,7 +2571,7 @@ def tflite_optimise_graph(nng, arch, force_symmetric_int_weights):
     )
 
     # Pre-processing step
-    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape]
+    pre_process_list = [supported_operator_check, set_ifm_ofm_op_shapes, fixup_reshape, convert_conv_groups]
 
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
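
For reference, the rewrite above relies on the standard decomposition of a grouped convolution: the ifm is split depth-wise into num_conv_groups parts, each part is convolved with its own contiguous slice of the filters, and the per-group results are concatenated back along the depth axis. The numpy sketch below is not Vela code; the NHWC/HWIO layouts, the 1x1 kernel and the conv1x1 helper are assumptions made for brevity. It checks that the split/conv/concat pipeline reproduces a grouped convolution computed by direct slicing:

# Standalone numpy sketch (not Vela code): a grouped convolution is equivalent to
# splitting the ifm depth-wise, convolving each part with its own slice of the
# filters, and concatenating the per-group results along the depth axis.
# A 1x1 kernel is used so that each convolution reduces to a matmul over channels.
import numpy as np

ifm_depth = 8          # total input channels
num_filters = 16       # total output channels
num_conv_groups = 4
ifm_depth_cg = ifm_depth // num_conv_groups      # input channels per group
num_filters_cg = num_filters // num_conv_groups  # filters per group

rng = np.random.default_rng(0)
ifm = rng.standard_normal((1, 5, 5, ifm_depth))                    # NHWC
weights = rng.standard_normal((1, 1, ifm_depth_cg, num_filters))   # HWIO, grouped


def conv1x1(x, w):
    # a 1x1 convolution is a per-pixel matmul over the channel axis
    return np.einsum("nhwc,co->nhwo", x, w[0, 0])


# reference: grouped convolution by direct slicing. group g reads input channels
# [g*ifm_depth_cg : (g+1)*ifm_depth_cg] and writes output channels
# [g*num_filters_cg : (g+1)*num_filters_cg] (cg_oc_start/cg_oc_end in the pass)
reference = np.concatenate(
    [
        conv1x1(
            ifm[..., g * ifm_depth_cg : (g + 1) * ifm_depth_cg],
            weights[..., g * num_filters_cg : (g + 1) * num_filters_cg],
        )
        for g in range(num_conv_groups)
    ],
    axis=-1,
)

# the rewrite's view: an explicit split, one convolution per group, then a concat
parts = np.split(ifm, num_conv_groups, axis=-1)
rewritten = np.concatenate(
    [
        conv1x1(part, weights[..., g * num_filters_cg : (g + 1) * num_filters_cg])
        for g, part in enumerate(parts)
    ],
    axis=-1,
)

assert rewritten.shape == (1, 5, 5, num_filters)
assert np.allclose(reference, rewritten)

The weight slice bounds g * num_filters_cg and (g + 1) * num_filters_cg above play the same role as cg_oc_start and cg_oc_end in convert_conv_groups.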