-rw-r--r--  ethosu/vela/live_range.py      2
-rw-r--r--  ethosu/vela/mark_tensors.py    4
-rw-r--r--  ethosu/vela/nn_graph.py        6
-rw-r--r--  ethosu/vela/pass_packing.py   15
-rw-r--r--  ethosu/vela/tensor.py         11
-rw-r--r--  ethosu/vela/tflite_reader.py  26
-rw-r--r--  ethosu/vela/tflite_writer.py  13
7 files changed, 67 insertions, 10 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index b18afecc..6a2a04ac 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -155,6 +155,8 @@ class LiveRangeGraph:
 
 
 def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
+    if tens.purpose == TensorPurpose.Virtual:
+        return True
     if target_mem_area is None or target_mem_type_set is None:
         return False
     if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
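The new guard is the heart of this change: a virtual tensor owns no memory, so the live-range builder must skip it no matter which memory area is being analysed, which is why the check comes before the target checks. A minimal, runnable sketch of the behaviour (the enum values and the Tensor class below are simplified stand-ins, not Vela's real ones):

import enum


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    FeatureMap = 2
    Virtual = 7


class Tensor:
    def __init__(self, purpose, mem_area=None, mem_type=None):
        self.purpose = purpose
        self.mem_area = mem_area
        self.mem_type = mem_type


def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
    # Virtual tensors carry no data, so they must never get a live range.
    if tens.purpose == TensorPurpose.Virtual:
        return True
    if target_mem_area is None or target_mem_type_set is None:
        return False
    if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
        return True
    return False


assert tensor_should_be_ignored(Tensor(TensorPurpose.Virtual), None, None)
assert not tensor_should_be_ignored(Tensor(TensorPurpose.FeatureMap), None, None)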
diff --git a/ethosu/vela/mark_tensors.py b/ethosu/vela/mark_tensors.py
index 64cc7883..4b5bf1dc 100644
--- a/ethosu/vela/mark_tensors.py
+++ b/ethosu/vela/mark_tensors.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -41,6 +41,8 @@ def mark_purpose(tens, arch, purpose):
     # Sets tensor's purpose, format, mem_area and mem_type
     if tens.purpose == TensorPurpose.Unknown:
         tens.purpose = purpose
+    elif tens.purpose == TensorPurpose.Virtual:
+        return
     elif tens.purpose not in (purpose, TensorPurpose.LUT):
         assert 0, "Cannot resolve tensor purpose {} and {} for tensor {}".format(tens.purpose, purpose, tens)
 
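The early return exists to protect the assert below it: purposes are normally resolved to a single value, and a virtual tensor visited during purpose marking would otherwise be reported as a conflict. A condensed, runnable sketch of the same logic (simplified stand-ins again, not Vela's classes):

import enum


class TensorPurpose(enum.IntFlag):
    Unknown = 0
    FeatureMap = 2
    LUT = 5
    Virtual = 7


class Tensor:
    def __init__(self, purpose):
        self.purpose = purpose


def mark_purpose(tens, purpose):
    if tens.purpose == TensorPurpose.Unknown:
        tens.purpose = purpose
    elif tens.purpose == TensorPurpose.Virtual:
        # Virtual tensors keep their purpose; no conflict is raised
        return
    elif tens.purpose not in (purpose, TensorPurpose.LUT):
        assert 0, f"Cannot resolve tensor purpose {tens.purpose} and {purpose}"


tens = Tensor(TensorPurpose.Virtual)
mark_purpose(tens, TensorPurpose.FeatureMap)  # no assert triggered
assert tens.purpose == TensorPurpose.Virtual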
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 846632df..a43aac2a 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -149,7 +149,11 @@ class Subgraph:
     def __init__(self, name="<unnamed>", placement=PassPlacement.Cpu):
         self.output_tensors = []
         self.input_tensors = []
-        self.original_inputs = []  # Preserve the original input order
+        # Preserve the original input order
+        self.original_inputs = []
+        # Attach virtual outputs to resource variable ops
+        # so that the graph can be traversed correctly
+        self.virtual_outputs = []
         self.passes = []
         self.cascaded_passes = []
         self.name = name
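Why the subgraph needs its own virtual_outputs list: Vela's traversals start from the output tensors and walk backwards through producing ops, so an op whose result no output consumes is never reached. Keeping the virtual tensors in a separate list also lets the writer strip them again later (see tflite_writer.py below). A rough sketch of the traversal idea, using hypothetical minimal Op/Tensor classes rather than Vela's:

class Op:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)


class Tensor:
    def __init__(self, name, ops=()):
        self.name = name
        self.ops = list(ops)  # the ops producing this tensor


def ops_reachable_from(output_tensors):
    # Depth-first walk from the outputs back through producing ops.
    visited = []

    def visit(tens):
        for op in tens.ops:
            if op not in visited:
                visited.append(op)
                for inp in op.inputs:
                    visit(inp)

    for tens in output_tensors:
        visit(tens)
    return visited


# An op with no consumed output is only found if its virtual output is
# included among the tensors the walk starts from.
var_handle = Op("VarHandle")
virtual_out = Tensor("VarHandle_virtual", ops=[var_handle])
assert ops_reachable_from([]) == []
assert ops_reachable_from([virtual_out]) == [var_handle]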
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 5c0d8ebe..6049366f 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -469,6 +469,8 @@ def pack_into_passes(nng, arch, verbose_packing=False):
     #
     # 1) CPU passes that only depend on sg.input_tensor can be
     #    moved to the top of the list.
+    # ResourceVariable ops like VarHandle, ReadVariable and CallOnce
+    # can also be moved to the top of the list.
     #
     # 2) A CPU pass X is allowed to be grouped together with CPU pass Y
     #    if there is no NPU pass between pass X and pass Y that depends
@@ -487,17 +489,20 @@ def pack_into_passes(nng, arch, verbose_packing=False):
             pass_list_top.insert(0, ps)
             continue
 
-        if (
-            ps.placement == PassPlacement.Cpu
-            and ps.ops[0].ifm in sg.input_tensors
+        if ps.placement == PassPlacement.Cpu and (
+            ps.ops[0].ifm in sg.input_tensors
             and (ps.ops[0].ifm2 in sg.input_tensors or ps.ops[0].ifm2 is None)
+            or (ps.ops[0].type in (Op.VarHandle, Op.ReadVariable, Op.CallOnce))
         ):
-            # This CPU pass only depends on sg.input_tensors
+            # This CPU pass only depends on sg.input_tensors or resource variable ops
             pass_list_top.append(ps)
         else:
             # Add pass to the list that will be sorted in the next step
             pass_list.append(ps)
 
+    # Sort ops by op_index (same call order as in the original graph)
+    pass_list_top = sorted(pass_list_top, key=lambda ps: -1 if ps.ops[0].op_index is None else ps.ops[0].op_index)
+
     # Sort the rest of the list based on criteria 2.
     # Search from bottom of list and when a CPU pass is found
     # search forward in the list and see if it is possible to join another CPU pass.
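The added sort restores the original call order among the hoisted passes, which matters for resource variables: a VarHandle op must run before the ReadVariable op that uses its handle. Passes whose first op has no op_index (ops Vela created itself) sort to the front via the -1 key. A standalone illustration with hypothetical stand-in classes:

class Op:
    def __init__(self, op_index):
        self.op_index = op_index


class Pass:
    def __init__(self, name, op_index):
        self.name = name
        self.ops = [Op(op_index)]


pass_list_top = [Pass("ReadVariable", 4), Pass("generated", None), Pass("VarHandle", 0)]
pass_list_top = sorted(pass_list_top, key=lambda ps: -1 if ps.ops[0].op_index is None else ps.ops[0].op_index)
print([ps.name for ps in pass_list_top])  # ['generated', 'VarHandle', 'ReadVariable']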
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 6a95bad4..008cd05e 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -114,7 +114,8 @@ class TensorPurpose(enum.IntFlag):
     ScratchFast = 4
     LUT = 5
     FSBias = 6
-    Size = 7
+    Virtual = 7
+    Size = 8
 
     def display_name(self) -> str:
         return ("Unknown", "Weights", "FeatureMap", "Scratch", "ScratchFast", "LUT", "FastStorageBias", "Size")[
@@ -297,6 +298,14 @@ class QuantizationParameters:
         return False
 
 
+def create_virtual_tensor(
+    name: str,
+):
+    virtual_tensor = Tensor([], DataType.int8, name)
+    virtual_tensor.purpose = TensorPurpose.Virtual
+    return virtual_tensor
+
+
 def create_const_tensor(
     name: str,
     shape: Shape,
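A virtual tensor is deliberately minimal: empty shape and an arbitrary dtype (int8), since it carries no data and exists only so that traversal reaches its producing op. A quick usage check, assuming the ethosu.vela package is importable (the tensor name is illustrative):

from ethosu.vela.tensor import TensorPurpose
from ethosu.vela.tensor import create_virtual_tensor

tens = create_virtual_tensor("AssignVariable_0")
assert tens.purpose == TensorPurpose.Virtual
assert tens.shape == []  # holds no data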
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 80f36457..2325ff65 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -32,6 +32,7 @@ from .reader_util import align_tensor_indices_to_nng
 from .reader_util import clone_and_reshape_tensor
 from .reader_util import decode_str
 from .reader_util import fixup_tensors
+from .tensor import create_virtual_tensor
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 from .tflite.BuiltinOperator import BuiltinOperator
@@ -51,6 +52,7 @@ class TFLiteSubgraph:
         for idx in range(subgraph.TensorsLength()):
             self.tensors.append(self.parse_tensor(subgraph.Tensors(idx)))
 
+        self.virtual_outputs = []
         for idx in range(subgraph.OperatorsLength()):
             self.parse_operator(idx, subgraph.Operators(idx))
 
@@ -58,6 +60,8 @@ class TFLiteSubgraph:
         self.inputs = self.get_tensors_from_indices_remove_duplicates(subgraph.InputsAsNumpy(), "input")
         fixup_tensors(self.inputs, self.tensors)
 
+        self.outputs.extend(self.virtual_outputs)
+
     def get_tensors_from_indices_remove_duplicates(self, indices, warning_str):
         tensors = []
         for idx in indices:
@@ -131,6 +135,21 @@ class TFLiteSubgraph:
         for out in op.outputs:
             out.ops = [op]
 
+        if op_type in (Op.AssignVariable, Op.CallOnce):
+            # All graph traversals are depth-first, starting from the
+            # subgraph output tensors. Operators like AssignVariable and
+            # CallOnce produce no output that a graph output consumes, so
+            # they would never be visited and never handled. Fixing this
+            # properly requires changes in several places in the code base.
+            # Until then this workaround is applied: a virtual output is
+            # added both to the operator and to the subgraph, so that the
+            # full graph is traversed correctly. The tensor is not used
+            # for anything else.
+            op.name = f"{op_type}_{op_index}"
+            tens = create_virtual_tensor(op.name)
+            op.set_output_tensor(tens)
+            self.virtual_outputs.append(tens)
+
         if op.type.is_depthwise_conv2d_op() or op.type.is_conv2d_op() or op.type == Op.FullyConnected:
             if inputs[1].values is not None:
                 if op.type == Op.FullyConnected:
@@ -156,6 +175,10 @@ class TFLiteSubgraph:
                 self.graph.nng.subgraphs[cond_subgraph_index],
                 self.graph.nng.subgraphs[body_subgraph_index],
             )
+        if op_type == Op.CallOnce:
+            # Attach the actual nng subgraphs to the op
+            init_subgraph_index = op.attrs["init_subgraph_index"]
+            op.attrs["subgraph"] = (self.graph.nng.subgraphs[init_subgraph_index],)
 
         if op_type == Op.Reshape and "new_shape" not in op.attrs:
             # Reshape should have an attrib "new_shape" but if it is missing, add it based on the output shape
@@ -250,6 +273,7 @@ class TFLiteGraph:
             sg.name = tflite_sg.name
             sg.original_inputs = tflite_sg.inputs  # Preserve the original input order
             sg.output_tensors = tflite_sg.outputs
+            sg.virtual_outputs = tflite_sg.virtual_outputs
 
         parsing_step = "parsing metadata length"
         # Preserve the original metadata
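Putting the reader changes together: each AssignVariable/CallOnce op receives a virtual output tensor, and that tensor is also appended to the subgraph outputs so the depth-first walk reaches the op. A compressed, runnable sketch of just that wiring, with hypothetical minimal classes rather than Vela's:

class Tensor:
    def __init__(self, name):
        self.name = name
        self.ops = []


class Operation:
    def __init__(self, name):
        self.name = name
        self.inputs = []
        self.outputs = []

    def set_output_tensor(self, tens):
        tens.ops = [self]
        self.outputs = [tens]


# Mirrors parse_operator: the op gets a virtual output, which is collected
# so it can be appended to the subgraph's outputs after parsing.
virtual_outputs = []
op = Operation("AssignVariable_3")
tens = Tensor(op.name)  # stands in for create_virtual_tensor(op.name)
op.set_output_tensor(tens)
virtual_outputs.append(tens)

outputs = []  # the real subgraph outputs parsed from the flatbuffer
outputs.extend(virtual_outputs)
assert outputs[0].ops == [op]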
diff --git a/ethosu/vela/tflite_writer.py b/ethosu/vela/tflite_writer.py
index e527cd4d..32982298 100644
--- a/ethosu/vela/tflite_writer.py
+++ b/ethosu/vela/tflite_writer.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -321,6 +321,10 @@ class TFLiteSerialiser:
attrs["dilation_w_factor"] = attrs["dilation"][2]
if "channel_multiplier" in attrs:
attrs["depth_multiplier"] = attrs["channel_multiplier"]
+ if "container" in attrs:
+ attrs["container"] = builder.CreateString(attrs["container"])
+ if "shared_name" in attrs:
+ attrs["shared_name"] = builder.CreateString(attrs["shared_name"])
attrs["fused_activation_function"] = op.activation.op_type if op.activation is not None else None
builtin_opt_offset, custom_opt_offset = opt_serializer.serialize(builder, attrs)
@@ -362,6 +366,13 @@ class TFLiteSerialiser:
         # to an op.
         tensor_set = set(sg.original_inputs)
 
+        # Remove any virtual outputs since they are only used internally when
+        # traversing the graph.
+        for tens in sg.virtual_outputs:
+            tens.ops[0].outputs = []
+            if tens in sg.output_tensors:
+                sg.output_tensors.remove(tens)
+
         # Add the tensors from all valid ops, as well as the tensors from placeholder ops
         # This allows us to serialise tensors which aren't attached to any specific ops,
         # e.g. due to an empty graph containing no ops
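A note on the CreateString calls above: in the Python flatbuffers API a string cannot be written directly into a table field; it must first be serialised into the buffer, and the resulting offset is what the options table stores. A minimal sketch using the flatbuffers package (the attrs dict is illustrative, not Vela's serialiser):

import flatbuffers

builder = flatbuffers.Builder(1024)

# Strings become buffer offsets before table construction; this is what
# the new "container"/"shared_name" handling does for VarHandle options.
attrs = {"container": "", "shared_name": "my_variable"}
if "container" in attrs:
    attrs["container"] = builder.CreateString(attrs["container"])
if "shared_name" in attrs:
    attrs["shared_name"] = builder.CreateString(attrs["shared_name"])
print(type(attrs["shared_name"]))  # an integer offset, ready for the table

The virtual-output cleanup is the mirror image of the reader workaround: detaching the tensors before serialisation guarantees they never appear in the written file.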