diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 8dc5efe1..fa90ad9e 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -147,6 +147,15 @@ class TFLiteSubgraph: if opt_serializer is not None: op.attrs = opt_serializer.deserialize(op_data) + if op_type == Op.While: + # Attach the actual nng subgraphs to the op + cond_subgraph_index = op.attrs["cond_subgraph_index"] + body_subgraph_index = op.attrs["body_subgraph_index"] + op.attrs["subgraph"] = ( + self.graph.nng.subgraphs[cond_subgraph_index], + self.graph.nng.subgraphs[body_subgraph_index], + ) + if op_type == Op.Reshape and "new_shape" not in op.attrs: # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape op.attrs["new_shape"] = outputs[0].shape @@ -223,16 +232,23 @@ class TFLiteGraph: parsing_step = "parsing subgraphs length" self.subgraphs = [] + + # Pre-allocate nng subgraphs - needed when parsing an operator and the operator + # has subgraph attributes. + self.nng = Graph(self.name, self.batch_size) + for idx in range(model.SubgraphsLength()): + sg = Subgraph() + self.nng.subgraphs.append(sg) + for idx in range(model.SubgraphsLength()): parsing_step = f"parsing subgraph {idx}" self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx))) - self.nng = Graph(self.name, self.batch_size) - for tflite_sg in self.subgraphs: - sg = Subgraph(tflite_sg.name) + for idx, tflite_sg in enumerate(self.subgraphs): + sg = self.nng.subgraphs[idx] + sg.name = tflite_sg.name sg.original_inputs = tflite_sg.inputs # Preserve the original input order sg.output_tensors = tflite_sg.outputs - self.nng.subgraphs.append(sg) parsing_step = "parsing metadata length" # Preserve the original metadata |