diff options
author | Johan Alfvén <johan.alfven@arm.com> | 2022-09-05 09:39:47 +0200 |
---|---|---|
committer | Johan Alfvén <johan.alfven@arm.com> | 2022-10-19 13:37:45 +0200 |
commit | 673683bb828cd552f1970922e3c61079607332b2 (patch) | |
tree | 02e6ca41621ca7ec32d7eb6f36cb755b8da14963 /ethosu/vela/tflite_reader.py | |
parent | d3d81b3ce138a48c0cddad7eb12710e26dad653e (diff) | |
download | ethos-u-vela-673683bb828cd552f1970922e3c61079607332b2.tar.gz |
MLBEDSW-6880: Add support for multiple subgraphs
- Vela failed to compile networks with multiple subgraphs because
only cascaded passes in the root subgraph were used when
extracting the live ranges. The fix is to extract the subgraph
range live on Ops that have connected subgraphs.
- The tf_writer did not handle multiple subgraphs in a correct way
resulting in corrupt buffer data in the optimized tflite file. The buffer
index must be unique for every tensor.
-Added support to handle multiple subgraphs for the OfflineMemoryAllocation
meta data. The change will not change behavior for single graphs.
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I2328dfc1f07e2e4faf43a75423ea95423096ffa3
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 8dc5efe1..fa90ad9e 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -147,6 +147,15 @@ class TFLiteSubgraph: if opt_serializer is not None: op.attrs = opt_serializer.deserialize(op_data) + if op_type == Op.While: + # Attach the actual nng subgraphs to the op + cond_subgraph_index = op.attrs["cond_subgraph_index"] + body_subgraph_index = op.attrs["body_subgraph_index"] + op.attrs["subgraph"] = ( + self.graph.nng.subgraphs[cond_subgraph_index], + self.graph.nng.subgraphs[body_subgraph_index], + ) + if op_type == Op.Reshape and "new_shape" not in op.attrs: # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape op.attrs["new_shape"] = outputs[0].shape @@ -223,16 +232,23 @@ class TFLiteGraph: parsing_step = "parsing subgraphs length" self.subgraphs = [] + + # Pre-allocate nng subgraphs - needed when parsing an operator and the operator + # has subgraph attributes. + self.nng = Graph(self.name, self.batch_size) + for idx in range(model.SubgraphsLength()): + sg = Subgraph() + self.nng.subgraphs.append(sg) + for idx in range(model.SubgraphsLength()): parsing_step = f"parsing subgraph {idx}" self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx))) - self.nng = Graph(self.name, self.batch_size) - for tflite_sg in self.subgraphs: - sg = Subgraph(tflite_sg.name) + for idx, tflite_sg in enumerate(self.subgraphs): + sg = self.nng.subgraphs[idx] + sg.name = tflite_sg.name sg.original_inputs = tflite_sg.inputs # Preserve the original input order sg.output_tensors = tflite_sg.outputs - self.nng.subgraphs.append(sg) parsing_step = "parsing metadata length" # Preserve the original metadata |