path: root/ethosu/vela/tflite_reader.py
author    Johan Alfven <johan.alfven@arm.com>    2023-02-07 13:01:03 +0100
committer Johan Alfven <johan.alfven@arm.com>    2023-02-14 13:29:26 +0100
commit    9070f0f1d9ee0fbf2cc3ee62a60f9b600bd62055 (patch)
tree      01ad61e2a33c9b976de53656743a24369ccc8119 /ethosu/vela/tflite_reader.py
parent    33c01e68984bf455d3a1f00c7f43ab2a6bb75cbe (diff)
download  ethos-u-vela-9070f0f1d9ee0fbf2cc3ee62a60f9b600bd62055.tar.gz
MLBEDSW-7316: Fix crash for networks with resource variables
- The problem was that networks with resource variables had not been
  considered. The main issue was the graph traversal, where these ops were
  never visited, which left an empty subgraph and caused the crash.
- Fixed the problem by attaching virtual tensors to the ops, simulating
  subgraph outputs. These tensors are only used to make the graph traversal
  work.
- Fixed serializing of attribute container and shared_name
- Fixed subgraph index for operator CallOnce
- All resource variable ops are pushed to the CPU

Change-Id: I815f9c81baf7a3fbb686e895980b462f58208b6e
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
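To illustrate the idea behind the fix, below is a minimal, standalone Python sketch. It uses simplified stand-in Op/Tensor classes and a toy traversal rather than the actual Vela implementation; only the create_virtual_tensor and set_output_tensor names mirror the patch. It shows why an op whose outputs are never consumed is skipped by a depth-first walk that starts from the subgraph output tensors, and how registering a virtual output tensor makes the walk reach it.

# Standalone sketch, not the real Vela classes: simplified stand-ins only.

class Tensor:
    def __init__(self, name):
        self.name = name
        self.ops = []  # operators that produce this tensor


class Op:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)
        self.outputs = []

    def set_output_tensor(self, tens):
        # Mirror the pattern used in the patch: mark this op as the producer
        tens.ops = [self]
        self.outputs.append(tens)


def create_virtual_tensor(name):
    # Placeholder tensor, only used so the graph traversal reaches the op
    return Tensor(f"{name}_virtual")


def ops_reached_depth_first(output_tensors):
    # Toy traversal that starts from the subgraph output tensors
    reached = []

    def visit(tens):
        for op in tens.ops:
            if op in reached:
                continue
            reached.append(op)
            for inp in op.inputs:
                visit(inp)

    for tens in output_tensors:
        visit(tens)
    return reached


# An AssignVariable/CallOnce-style op: it consumes a resource variable but
# produces nothing that any subgraph output depends on.
var = Tensor("resource_variable")
assign = Op("AssignVariable_0", inputs=[var])

subgraph_outputs = []
assert assign not in ops_reached_depth_first(subgraph_outputs)  # op is skipped

# Workaround from the patch: attach a virtual output tensor to the op and
# treat it as a subgraph output, so the depth-first walk now visits the op.
virtual = create_virtual_tensor(assign.name)
assign.set_output_tensor(virtual)
subgraph_outputs.append(virtual)
assert assign in ops_reached_depth_first(subgraph_outputs)

In the actual patch the virtual tensors are collected in self.virtual_outputs while operators are parsed and then appended to the subgraph outputs, as shown in the diff below.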
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--  ethosu/vela/tflite_reader.py | 26
1 file changed, 25 insertions(+), 1 deletion(-)
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 80f36457..2325ff65 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -32,6 +32,7 @@ from .reader_util import align_tensor_indices_to_nng
from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
+from .tensor import create_virtual_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tflite.BuiltinOperator import BuiltinOperator
@@ -51,6 +52,7 @@ class TFLiteSubgraph:
for idx in range(subgraph.TensorsLength()):
self.tensors.append(self.parse_tensor(subgraph.Tensors(idx)))
+ self.virtual_outputs = []
for idx in range(subgraph.OperatorsLength()):
self.parse_operator(idx, subgraph.Operators(idx))
@@ -58,6 +60,8 @@ class TFLiteSubgraph:
self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input")
fixup_tensors(self.inputs, self.tensors)
+ self.outputs.extend(self.virtual_outputs)
+
def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
tensors = []
for idx in indices:
@@ -131,6 +135,21 @@ class TFLiteSubgraph:
for out in op.outputs:
out.ops = [op]
+ if op_type in (Op.AssignVariable, Op.CallOnce):
+ # All graph traversals are based on depth-first and the starting
+ # points are the subgraph output tensors. Because of this, operators
+ # like AssignVariable and CallOnce will not be visited when the
+ # graph is traversed and the ops are never handled. In order to
+ # fix that, the code base will have to be changed in several places.
+ # Until then this workaround is applied. A virtual output is added
+ # both to the operator and to the subgraph. By doing this the full
+ # graph is traversed correctly. The tensor is not used for anything
+ # else.
+ op.name = f"{op_type}_{op_index}"
+ tens = create_virtual_tensor(op.name)
+ op.set_output_tensor(tens)
+ self.virtual_outputs.append(tens)
+
if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
if inputs[1].values is not None:
if op.type == Op.FullyConnected:
@@ -156,6 +175,10 @@ class TFLiteSubgraph:
self.graph.nng.subgraphs[cond_subgraph_index],
self.graph.nng.subgraphs[body_subgraph_index],
)
+ if op_type == Op.CallOnce:
+ # Attach the actual nng subgraphs to the op
+ init_subgraph_index = op.attrs["init_subgraph_index"]
+ op.attrs["subgraph"] = (self.graph.nng.subgraphs[init_subgraph_index],)
if op_type == Op.Reshape and "new_shape" not in op.attrs:
# Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
@@ -250,6 +273,7 @@ class TFLiteGraph:
sg.name = tflite_sg.name
sg.original_inputs = tflite_sg.inputs # Preserve the original input order
sg.output_tensors = tflite_sg.outputs
+ sg.virtual_outputs = tflite_sg.virtual_outputs
parsing_step = "parsing metadata length"
# Preserve the original metadata