aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorJohan Alfvén <johan.alfven@arm.com>2022-09-05 09:39:47 +0200
committerJohan Alfvén <johan.alfven@arm.com>2022-10-19 13:37:45 +0200
commit673683bb828cd552f1970922e3c61079607332b2 (patch)
tree02e6ca41621ca7ec32d7eb6f36cb755b8da14963 /ethosu/vela/tflite_reader.py
parentd3d81b3ce138a48c0cddad7eb12710e26dad653e (diff)
downloadethos-u-vela-673683bb828cd552f1970922e3c61079607332b2.tar.gz
MLBEDSW-6880: Add support for multiple subgraphs
- Vela failed to compile networks with multiple subgraphs because only cascaded passes in the root subgraph were used when extracting the live ranges. The fix is to extract the subgraph range live on Ops that have connected subgraphs. - The tf_writer did not handle multiple subgraphs in a correct way resulting in corrupt buffer data in the optimized tflite file. The buffer index must be unique for every tensor. -Added support to handle multiple subgraphs for the OfflineMemoryAllocation meta data. The change will not change behavior for single graphs. Signed-off-by: Johan Alfven <johan.alfven@arm.com> Change-Id: I2328dfc1f07e2e4faf43a75423ea95423096ffa3
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py26
1 files changed, 21 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 8dc5efe1..fa90ad9e 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -147,6 +147,15 @@ class TFLiteSubgraph:
if opt_serializer is not None:
op.attrs = opt_serializer.deserialize(op_data)
+ if op_type == Op.While:
+ # Attach the actual nng subgraphs to the op
+ cond_subgraph_index = op.attrs["cond_subgraph_index"]
+ body_subgraph_index = op.attrs["body_subgraph_index"]
+ op.attrs["subgraph"] = (
+ self.graph.nng.subgraphs[cond_subgraph_index],
+ self.graph.nng.subgraphs[body_subgraph_index],
+ )
+
if op_type == Op.Reshape and "new_shape" not in op.attrs:
# Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
op.attrs["new_shape"] = outputs[0].shape
@@ -223,16 +232,23 @@ class TFLiteGraph:
parsing_step = "parsing subgraphs length"
self.subgraphs = []
+
+ # Pre-allocate nng subgraphs - needed when parsing an operator and the operator
+ # has subgraph attributes.
+ self.nng = Graph(self.name, self.batch_size)
+ for idx in range(model.SubgraphsLength()):
+ sg = Subgraph()
+ self.nng.subgraphs.append(sg)
+
for idx in range(model.SubgraphsLength()):
parsing_step = f"parsing subgraph {idx}"
self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx)))
- self.nng = Graph(self.name, self.batch_size)
- for tflite_sg in self.subgraphs:
- sg = Subgraph(tflite_sg.name)
+ for idx, tflite_sg in enumerate(self.subgraphs):
+ sg = self.nng.subgraphs[idx]
+ sg.name = tflite_sg.name
sg.original_inputs = tflite_sg.inputs # Preserve the original input order
sg.output_tensors = tflite_sg.outputs
- self.nng.subgraphs.append(sg)
parsing_step = "parsing metadata length"
# Preserve the original metadata