aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/nn_graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r--ethosu/vela/nn_graph.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 50266d5..92c7e1b 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -318,6 +318,7 @@ class Subgraph:
ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]
+ # get_all_ops is used when traversing the original graph
def get_all_ops(self):
all_ops = []
visit_op_set = set()
@@ -344,6 +345,20 @@ class Subgraph:
return all_ops
+ # get_all_ops_from_passes is used by stats writer to calculate the number of
+ # CPU and NPU ops
+ # Due to a side effect get_all_ops might not be traversing the full graph
+ # after extract_npu_subgraph have been called and should not be used by stats writer.
+ # The reason is that the main graph might have NPU nodes with no visible outputs
+ # and therefore the nodes will be missed.
+ def get_all_ops_from_passes(self):
+ all_ops = []
+ for idx, ps in enumerate(self.passes):
+ for op in ps.ops:
+ all_ops.append(op)
+
+ return all_ops
+
def print_operators(self, ignore_placeholder_const=True, show_attributes=True):
print(f"Operators of Subgraph {self.name}")