diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 80f36457..2325ff65 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -32,6 +32,7 @@ from .reader_util import align_tensor_indices_to_nng from .reader_util import clone_and_reshape_tensor from .reader_util import decode_str from .reader_util import fixup_tensors +from .tensor import create_virtual_tensor from .tensor import QuantizationParameters from .tensor import Tensor from .tflite.BuiltinOperator import BuiltinOperator @@ -51,6 +52,7 @@ class TFLiteSubgraph: for idx in range(subgraph.TensorsLength()): self.tensors.append(self.parse_tensor(subgraph.Tensors(idx))) + self.virtual_outputs = [] for idx in range(subgraph.OperatorsLength()): self.parse_operator(idx, subgraph.Operators(idx)) @@ -58,6 +60,8 @@ class TFLiteSubgraph: self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input") fixup_tensors(self.inputs, self.tensors) + self.outputs.extend(self.virtual_outputs) + def get_tensors_from_indices_remove_duplicates(self, indices, warning_str): tensors = [] for idx in indices: @@ -131,6 +135,21 @@ class TFLiteSubgraph: for out in op.outputs: out.ops = [op] + if op_type in (Op.AssignVariable, Op.CallOnce): + # All graph traversals are based on depth-first and the starting + # points are the subgraph output tensors. Because of this, operators + # like AssignVariable and CallOnce will not be visit when the + # graph is traversed and the ops are never handled. In order to + # fix that, the code base will have to be changed in several places. + # Until then this workaround is applied. A virtual output is added + # both to the operator and to the subgraph. By doing this the full + # graph is traversed correctly. The tensor is not used for anything + # else. + op.name = f"{op_type}_{op_index}" + tens = create_virtual_tensor(op.name) + op.set_output_tensor(tens) + self.virtual_outputs.append(tens) + if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected: if inputs[1].values is not None: if op.type == Op.FullyConnected: @@ -156,6 +175,10 @@ class TFLiteSubgraph: self.graph.nng.subgraphs[cond_subgraph_index], self.graph.nng.subgraphs[body_subgraph_index], ) + if op_type == Op.CallOnce: + # Attach the actual nng subgraphs to the op + init_subgraph_index = op.attrs["init_subgraph_index"] + op.attrs["subgraph"] = (self.graph.nng.subgraphs[init_subgraph_index],) if op_type == Op.Reshape and "new_shape" not in op.attrs: # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape @@ -250,6 +273,7 @@ class TFLiteGraph: sg.name = tflite_sg.name sg.original_inputs = tflite_sg.inputs # Preserve the original input order sg.output_tensors = tflite_sg.outputs + sg.virtual_outputs = tflite_sg.virtual_outputs parsing_step = "parsing metadata length" # Preserve the original metadata |