about summary refs log tree commit diff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py  57
1 file changed, 45 insertions, 12 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index fb8a08c0..88d58a32 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -512,7 +512,10 @@ def add_padding_fields(op, arch, nng):
)
else:
padding, skirt = calc_padding_and_skirt(
- op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
+ op.attrs["padding"],
+ op.kernel,
+ input_shape,
+ op.attrs.get("explicit_padding"),
)
op.attrs["explicit_padding"] = padding
@@ -642,11 +645,11 @@ def convert_softmax(op, arch, nng):
def convert_mul_max_to_abs_or_lrelu(op, arch, nng):
r"""Whenever there is a subgraph with this topology:
- Input X For X = -1 or X > 0
- | \ / This subgraph can be replaced with either
- | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
- | /
- Max
+ Input X For X = -1 or X > 0
+ | \ / This subgraph can be replaced with either
+ | Mul an Abs (if X = -1) or a LeakyReLU (if X > 0)
+ | /
+ Max
"""
if op.type == Op.Maximum:
@@ -1246,7 +1249,12 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
quant = QuantizationParameters()
quant.zero_point = 0
bias_term_tens = create_const_tensor(
- op.name + "_bias", [1, 1, 1, 1], DataType.int16, [bias_term], np.int16, quantization=quant,
+ op.name + "_bias",
+ [1, 1, 1, 1],
+ DataType.int16,
+ [bias_term],
+ np.int16,
+ quantization=quant,
)
add_op.add_input_tensor(bias_term_tens)
add_op.set_output_tensor(op.ofm)
@@ -1370,7 +1378,12 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
bias_shape = [shape[-1]]
op.set_input_tensor(
create_const_tensor(
- "bias", bias_shape, inp.dtype, np.ones(bias_shape) * bias, value_dtype=np.int32, quantization=None,
+ "bias",
+ bias_shape,
+ inp.dtype,
+ np.ones(bias_shape) * bias,
+ value_dtype=np.int32,
+ quantization=None,
),
2,
)
@@ -1392,7 +1405,12 @@ def tflite_optimise_graph(nng, arch):
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ pre_process_list,
+ rewrite_unsupported=False,
)
# Handle Concat Ops
@@ -1413,13 +1431,23 @@ def tflite_optimise_graph(nng, arch):
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [rewrite_split_ops],
+ [],
+ rewrite_unsupported=False,
)
# Handle sg input output
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ [fix_sg_input_output],
+ rewrite_unsupported=False,
)
# Removal of memory only operators
@@ -1452,7 +1480,12 @@ def tflite_optimise_graph(nng, arch):
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
+ nng,
+ sg,
+ arch,
+ [],
+ op_rewrite_list,
+ rewrite_unsupported=False,
)
for idx, sg in enumerate(nng.subgraphs):