aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_writer.py')
-rw-r--r--ethosu/vela/tflite_writer.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index e527cd4d..32982298 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -321,6 +321,10 @@ class TFLiteSerialiser:
attrs["dilation_w_factor"] = attrs["dilation"][2]
if "channel_multiplier" in attrs:
attrs["depth_multiplier"] = attrs["channel_multiplier"]
+ if "container" in attrs:
+ attrs["container"] = builder.CreateString(attrs["container"])
+ if "shared_name" in attrs:
+ attrs["shared_name"] = builder.CreateString(attrs["shared_name"])
attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None
builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
@@ -362,6 +366,13 @@ class TFLiteSerialiser:
# to an op.
tensor_set = set(sg.original_inputs)
+ # Remove any virtual outputs since they are only used internally when
+ # traversing the graph.
+ for tens in sg.virtual_outputs:
+ tens.ops[0].outputs = []
+ if tens in sg.output_tensors:
+ sg.output_tensors.remove(tens)
+
# Add the tensors from all valid ops, as well as the tensors from placeholder ops
# This allows us to serialise tensors which arent attached to any specific ops,
# e.g. due to an empty graph containing no ops