diff options
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r-- | ethosu/vela/tflite_graph_optimiser.py | 57 |
1 files changed, 45 insertions, 12 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py index fb8a08c0..88d58a32 100644 --- a/ethosu/vela/tflite_graph_optimiser.py +++ b/ethosu/vela/tflite_graph_optimiser.py @@ -512,7 +512,10 @@ def add_padding_fields(op, arch, nng): ) else: padding, skirt = calc_padding_and_skirt( - op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"), + op.attrs["padding"], + op.kernel, + input_shape, + op.attrs.get("explicit_padding"), ) op.attrs["explicit_padding"] = padding @@ -642,11 +645,11 @@ def convert_softmax(op, arch, nng): def convert_mul_max_to_abs_or_lrelu(op, arch, nng): r"""Whenever there is a subgraph with this topology: - Input X For X = -1 or X > 0 - | \ / This subgraph can be replaced with either - | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0) - | / - Max + Input X For X = -1 or X > 0 + | \ / This subgraph can be replaced with either + | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0) + | / + Max """ if op.type == Op.Maximum: @@ -1246,7 +1249,12 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): quant = QuantizationParameters() quant.zero_point = 0 bias_term_tens = create_const_tensor( - op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant, + op.name + "_bias", + [1, 1, 1, 1], + DataType.int16, + [bias_term], + np.int16, + quantization=quant, ) add_op.add_input_tensor(bias_term_tens) add_op.set_output_tensor(op.ofm) @@ -1370,7 +1378,12 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng): bias_shape = [shape[-1]] op.set_input_tensor( create_const_tensor( - "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None, + "bias", + bias_shape, + inp.dtype, + np.ones(bias_shape) * bias, + value_dtype=np.int32, + quantization=None, ), 2, ) @@ -1392,7 +1405,12 @@ def tflite_optimise_graph(nng, arch): for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], pre_process_list, rewrite_unsupported=False, + nng, + sg, + arch, + [], + pre_process_list, + rewrite_unsupported=False, ) # Handle Concat Ops @@ -1413,13 +1431,23 @@ def tflite_optimise_graph(nng, arch): for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False, + nng, + sg, + arch, + [rewrite_split_ops], + [], + rewrite_unsupported=False, ) # Handle sg input output for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False, + nng, + sg, + arch, + [], + [fix_sg_input_output], + rewrite_unsupported=False, ) # Removal of memory only operators @@ -1452,7 +1480,12 @@ def tflite_optimise_graph(nng, arch): for idx, sg in enumerate(nng.subgraphs): nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order( - nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False, + nng, + sg, + arch, + [], + op_rewrite_list, + rewrite_unsupported=False, ) for idx, sg in enumerate(nng.subgraphs): |