aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py26
1 files changed, 21 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 8dc5efe1..fa90ad9e 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -147,6 +147,15 @@ class TFLiteSubgraph:
if opt_serializer is not None:
op.attrs = opt_serializer.deserialize(op_data)
+ if op_type == Op.While:
+ # Attach the actual nng subgraphs to the op
+ cond_subgraph_index = op.attrs["cond_subgraph_index"]
+ body_subgraph_index = op.attrs["body_subgraph_index"]
+ op.attrs["subgraph"] = (
+ self.graph.nng.subgraphs[cond_subgraph_index],
+ self.graph.nng.subgraphs[body_subgraph_index],
+ )
+
if op_type == Op.Reshape and "new_shape" not in op.attrs:
# Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
op.attrs["new_shape"] = outputs[0].shape
@@ -223,16 +232,23 @@ class TFLiteGraph:
parsing_step = "parsing subgraphs length"
self.subgraphs = []
+
+ # Pre-allocate nng subgraphs - needed when parsing an operator and the operator
+ # has subgraph attributes.
+ self.nng = Graph(self.name, self.batch_size)
+ for idx in range(model.SubgraphsLength()):
+ sg = Subgraph()
+ self.nng.subgraphs.append(sg)
+
for idx in range(model.SubgraphsLength()):
parsing_step = f"parsing subgraph {idx}"
self.subgraphs.append(TFLiteSubgraph(self, model.Subgraphs(idx)))
- self.nng = Graph(self.name, self.batch_size)
- for tflite_sg in self.subgraphs:
- sg = Subgraph(tflite_sg.name)
+ for idx, tflite_sg in enumerate(self.subgraphs):
+ sg = self.nng.subgraphs[idx]
+ sg.name = tflite_sg.name
sg.original_inputs = tflite_sg.inputs # Preserve the original input order
sg.output_tensors = tflite_sg.outputs
- self.nng.subgraphs.append(sg)
parsing_step = "parsing metadata length"
# Preserve the original metadata