aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-02-07 13:01:03 +0100
committerJohan Alfven <johan.alfven@arm.com>2023-02-14 13:29:26 +0100
commit9070f0f1d9ee0fbf2cc3ee62a60f9b600bd62055 (patch)
tree01ad61e2a33c9b976de53656743a24369ccc8119 /ethosu/vela/tflite_writer.py
parent33c01e68984bf455d3a1f00c7f43ab2a6bb75cbe (diff)
downloadethos-u-vela-9070f0f1d9ee0fbf2cc3ee62a60f9b600bd62055.tar.gz
MLBEDSW-7316: Fix crash for networks with resource variables
- The problem was that networks with resource variables have not been thought of. The major problem was the graph traversal where these ops were not visited resulting in an empty subgraph that resulted in the crash. - Fixed the problem by attaching virtual tensors to the ops simulating subgraph output. These tensors are only used to get the graph traversal to work. - Fixed serializing of attribute container and shared_name - Fixed subgraph index for operator CallOnce - All resource variable ops are pushed to the CPU Change-Id: I815f9c81baf7a3fbb686e895980b462f58208b6e Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index e527cd4d..32982298 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -321,6 +321,10 @@ class TFLiteSerialiser:
attrs["dilation_w_factor"] = attrs["dilation"][2]
if "channel_multiplier" in attrs:
attrs["depth_multiplier"] = attrs["channel_multiplier"]
+ if "container" in attrs:
+ attrs["container"] = builder.CreateString(attrs["container"])
+ if "shared_name" in attrs:
+ attrs["shared_name"] = builder.CreateString(attrs["shared_name"])
attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None
builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
@@ -362,6 +366,13 @@ class TFLiteSerialiser:
# to an op.
tensor_set = set(sg.original_inputs)
+ # Remove any virtual outputs since they are only used internally when
+ # traversing the graph.
+ for tens in sg.virtual_outputs:
+ tens.ops[0].outputs = []
+ if tens in sg.output_tensors:
+ sg.output_tensors.remove(tens)
+
# Add the tensors from all valid ops, as well as the tensors from placeholder ops
# This allows us to serialise tensors which arent attached to any specific ops,
# e.g. due to an empty graph containing no ops