aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Alfven <johan.alfven@arm.com>2023-02-07 13:01:03 +0100
committerJohan Alfven <johan.alfven@arm.com>2023-02-14 13:29:26 +0100
commit9070f0f1d9ee0fbf2cc3ee62a60f9b600bd62055 (patch)
tree01ad61e2a33c9b976de53656743a24369ccc8119
parent33c01e68984bf455d3a1f00c7f43ab2a6bb75cbe (diff)
downloadethos-u-vela-9070f0f1d9ee0fbf2cc3ee62a60f9b600bd62055.tar.gz
MLBEDSW-7316: Fix crash for networks with resource variables
- The problem was that networks with resource variables have not been thought of. The major problem was the graph traversal where these ops were not visited resulting in an empty subgraph that resulted in the crash. - Fixed the problem by attaching virtual tensors to the ops simulating subgraph output. These tensors are only used to get the graph traversal to work. - Fixed serializing of attribute container and shared_name - Fixed subgraph index for operator CallOnce - All resource variable ops are pushed to the CPU Change-Id: I815f9c81baf7a3fbb686e895980b462f58208b6e Signed-off-by: Johan Alfven <johan.alfven@arm.com>
-rw-r--r--ethosu/vela/live_range.py2
-rw-r--r--ethosu/vela/mark_tensors.py4
-rw-r--r--ethosu/vela/nn_graph.py6
-rw-r--r--ethosu/vela/pass_packing.py15
-rw-r--r--ethosu/vela/tensor.py11
-rw-r--r--ethosu/vela/tflite_reader.py26
-rw-r--r--ethosu/vela/tflite_writer.py13
7 files changed, 67 insertions, 10 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index b18afecc..6a2a04ac 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -155,6 +155,8 @@ class LiveRangeGraph:
def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+ if tens.purpose == TensorPurpose.Virtual:
+ return True
if target_mem_area is None or target_mem_type_set is None:
return False
if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 64cc7883..4b5bf1dc 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -41,6 +41,8 @@ def mark_purpose(tens, arch, purpose):
# Sets tensor's purpose, format, mem_area and mem_type
if tens.purpose == TensorPurpose.Unknown:
tens.purpose = purpose
+ elif tens.purpose == TensorPurpose.Virtual:
+ return
elif tens.purpose not in (purpose, TensorPurpose.LUT):
assert 0, "Cannot resolve tensor purpose {} and {} for tensor {}".format(tens.purpose, purpose, tens)
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 846632df..a43aac2a 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -149,7 +149,11 @@ class Subgraph:
def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
self.output_tensors = []
self.input_tensors = []
- self.original_inputs = [] # Preserve the original input order
+ # Preserve the original input order
+ self.original_inputs = []
+ # Attach virtual outputs to resource variables op
+ # in order to be able to traverse the graph correctly
+ self.virtual_outputs = []
self.passes = []
self.cascaded_passes = []
self.name = name
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 5c0d8ebe..6049366f 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -469,6 +469,8 @@ def pack_into_passes(nng, arch, verbose_packing=False):
#
# 1) CPU passes that only depends on sg.input_tensor can be
# moved to the top of the list.
+ # ResourceVariables ops like VarHandle, ReadVariable, CallOnce
+ # can also be moved to the top of list.
#
# 2) A CPU pass X is allowed to be grouped together with CPU pass Y
# if there is no NPU pass between pass X and pass Y that depends
@@ -487,17 +489,20 @@ def pack_into_passes(nng, arch, verbose_packing=False):
pass_list_top.insert(0, ps)
continue
- if (
- ps.placement == PassPlacement.Cpu
- and ps.ops[0].ifm in sg.input_tensors
+ if ps.placement == PassPlacement.Cpu and (
+ ps.ops[0].ifm in sg.input_tensors
and (ps.ops[0].ifm2 in sg.input_tensors or ps.ops[0].ifm2 is None)
+ or (ps.ops[0].type in (Op.VarHandle, Op.ReadVariable, Op.CallOnce))
):
- # This CPU pass only depends on sg.input_tensors
+ # This CPU pass only depends on sg.input_tensors or resource variable
pass_list_top.append(ps)
else:
# Add pass to the list that will be sorted in the next step
pass_list.append(ps)
+ # Sort ops by op_index (same call order as in the original graph)
+ pass_list_top = sorted(pass_list_top, key=lambda ps: -1 if ps.ops[0].op_index is None else ps.ops[0].op_index)
+
# Sort the rest of the list based on critera 2.
# Search from bottom of list and when a CPU pass is found
# search forward in the list and see if it is possible to join another CPU pass.
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 6a95bad4..008cd05e 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -114,7 +114,8 @@ class TensorPurpose(enum.IntFlag):
ScratchFast = 4
LUT = 5
FSBias = 6
- Size = 7
+ Virtual = 7
+ Size = 8
def display_name(self) -> str:
return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
@@ -297,6 +298,14 @@ class QuantizationParameters:
return False
+def create_virtual_tensor(
+ name: str,
+):
+ virtual_tensor = Tensor([], DataType.int8, name)
+ virtual_tensor.purpose = TensorPurpose.Virtual
+ return virtual_tensor
+
+
def create_const_tensor(
name: str,
shape: Shape,
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 80f36457..2325ff65 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -32,6 +32,7 @@ from .reader_util import align_tensor_indices_to_nng
from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
+from .tensor import create_virtual_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor
from .tflite.BuiltinOperator import BuiltinOperator
@@ -51,6 +52,7 @@ class TFLiteSubgraph:
for idx in range(subgraph.TensorsLength()):
self.tensors.append(self.parse_tensor(subgraph.Tensors(idx)))
+ self.virtual_outputs = []
for idx in range(subgraph.OperatorsLength()):
self.parse_operator(idx, subgraph.Operators(idx))
@@ -58,6 +60,8 @@ class TFLiteSubgraph:
self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input")
fixup_tensors(self.inputs, self.tensors)
+ self.outputs.extend(self.virtual_outputs)
+
def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
tensors = []
for idx in indices:
@@ -131,6 +135,21 @@ class TFLiteSubgraph:
for out in op.outputs:
out.ops = [op]
+ if op_type in (Op.AssignVariable, Op.CallOnce):
+ # All graph traversals are based on depth-first and the starting
+ # points are the subgraph output tensors. Because of this, operators
+ # like AssignVariable and CallOnce will not be visit when the
+ # graph is traversed and the ops are never handled. In order to
+ # fix that, the code base will have to be changed in several places.
+ # Until then this workaround is applied. A virtual output is added
+ # both to the operator and to the subgraph. By doing this the full
+ # graph is traversed correctly. The tensor is not used for anything
+ # else.
+ op.name = f"{op_type}_{op_index}"
+ tens = create_virtual_tensor(op.name)
+ op.set_output_tensor(tens)
+ self.virtual_outputs.append(tens)
+
if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
if inputs[1].values is not None:
if op.type == Op.FullyConnected:
@@ -156,6 +175,10 @@ class TFLiteSubgraph:
self.graph.nng.subgraphs[cond_subgraph_index],
self.graph.nng.subgraphs[body_subgraph_index],
)
+ if op_type == Op.CallOnce:
+ # Attach the actual nng subgraphs to the op
+ init_subgraph_index = op.attrs["init_subgraph_index"]
+ op.attrs["subgraph"] = (self.graph.nng.subgraphs[init_subgraph_index],)
if op_type == Op.Reshape and "new_shape" not in op.attrs:
# Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
@@ -250,6 +273,7 @@ class TFLiteGraph:
sg.name = tflite_sg.name
sg.original_inputs = tflite_sg.inputs # Preserve the original input order
sg.output_tensors = tflite_sg.outputs
+ sg.virtual_outputs = tflite_sg.virtual_outputs
parsing_step = "parsing metadata length"
# Preserve the original metadata
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index e527cd4d..32982298 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -321,6 +321,10 @@ class TFLiteSerialiser:
attrs["dilation_w_factor"] = attrs["dilation"][2]
if "channel_multiplier" in attrs:
attrs["depth_multiplier"] = attrs["channel_multiplier"]
+ if "container" in attrs:
+ attrs["container"] = builder.CreateString(attrs["container"])
+ if "shared_name" in attrs:
+ attrs["shared_name"] = builder.CreateString(attrs["shared_name"])
attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None
builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
@@ -362,6 +366,13 @@ class TFLiteSerialiser:
# to an op.
tensor_set = set(sg.original_inputs)
+ # Remove any virtual outputs since they are only used internally when
+ # traversing the graph.
+ for tens in sg.virtual_outputs:
+ tens.ops[0].outputs = []
+ if tens in sg.output_tensors:
+ sg.output_tensors.remove(tens)
+
# Add the tensors from all valid ops, as well as the tensors from placeholder ops
# This allows us to serialise tensors which arent attached to any specific ops,
# e.g. due to an empty graph containing no ops