aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py26
1 files changed, 25 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 80f36457..2325ff65 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -32,6 +32,7 @@ from .reader_util import align_tensor_indices_to_nng
from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
+from .tensor import create_virtual_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tflite.BuiltinOperator import BuiltinOperator
@@ -51,6 +52,7 @@ class TFLiteSubgraph:
for idx in range(subgraph.TensorsLength()):
self.tensors.append(self.parse_tensor(subgraph.Tensors(idx)))
+ self.virtual_outputs = []
for idx in range(subgraph.OperatorsLength()):
self.parse_operator(idx, subgraph.Operators(idx))
@@ -58,6 +60,8 @@ class TFLiteSubgraph:
self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input")
fixup_tensors(self.inputs, self.tensors)
+ self.outputs.extend(self.virtual_outputs)
+
def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
tensors = []
for idx in indices:
@@ -131,6 +135,21 @@ class TFLiteSubgraph:
for out in op.outputs:
out.ops = [op]
+ if op_type in (Op.AssignVariable, Op.CallOnce):
+ # All graph traversals are based on depth-first and the starting
+ # points are the subgraph output tensors. Because of this, operators
+ # like AssignVariable and CallOnce will not be visit when the
+ # graph is traversed and the ops are never handled. In order to
+ # fix that, the code base will have to be changed in several places.
+ # Until then this workaround is applied. A virtual output is added
+ # both to the operator and to the subgraph. By doing this the full
+ # graph is traversed correctly. The tensor is not used for anything
+ # else.
+ op.name = f"{op_type}_{op_index}"
+ tens = create_virtual_tensor(op.name)
+ op.set_output_tensor(tens)
+ self.virtual_outputs.append(tens)
+
if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
if inputs[1].values is not None:
if op.type == Op.FullyConnected:
@@ -156,6 +175,10 @@ class TFLiteSubgraph:
self.graph.nng.subgraphs[cond_subgraph_index],
self.graph.nng.subgraphs[body_subgraph_index],
)
+ if op_type == Op.CallOnce:
+ # Attach the actual nng subgraphs to the op
+ init_subgraph_index = op.attrs["init_subgraph_index"]
+ op.attrs["subgraph"] = (self.graph.nng.subgraphs[init_subgraph_index],)
if op_type == Op.Reshape and "new_shape" not in op.attrs:
# Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
@@ -250,6 +273,7 @@ class TFLiteGraph:
sg.name = tflite_sg.name
sg.original_inputs = tflite_sg.inputs # Preserve the original input order
sg.output_tensors = tflite_sg.outputs
+ sg.virtual_outputs = tflite_sg.virtual_outputs
parsing_step = "parsing metadata length"
# Preserve the original metadata