aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_graph_optimiser.py')
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 06395784..cf3985e4 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1391,12 +1391,43 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
return op
+def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
+ """Static optimisation for SHAPE operator output value known at compile time"""
+
+ # Disconnect SHAPE operator from its parent and transform SHAPE OP into constant
+
+ if op.type == Op.Shape and op.run_on_npu:
+
+ ifm, ofm = op.get_ifm_ofm()
+
+ if len(ifm.shape) != ofm.shape[0]:
+ return op
+
+ # Remove reference of the current shape op from the parent tensor's consumer list
+ ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer.op_index != op.op_index]
+
+ # Clear any references to parent node
+ op.inputs = []
+
+ # Convert this SHAPE op to const
+ op.type = Op.Const
+
+ # Add size calculation to shape output tensors
+ ofm.values = np.array(ifm.shape)
+
+ return op
+
+
def supported_operator_check(op, arch, nng):
op.run_on_npu = arch.tflite_supported_operators.is_operator_supported(op)
return op
def tflite_optimise_graph(nng, arch):
+
+ # Compile time optimisations
+ optimisation_list = [convert_shape_op_to_constant_tensor]
+
# Pre-processing step
pre_process_list = [
supported_operator_check,
@@ -1409,6 +1440,16 @@ def tflite_optimise_graph(nng, arch):
sg,
arch,
[],
+ optimisation_list,
+ rewrite_unsupported=False,
+ )
+
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+ nng,
+ sg,
+ arch,
+ [],
pre_process_list,
rewrite_unsupported=False,
)