aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/nn_graph.py17
-rw-r--r--ethosu/vela/stats_writer.py6
2 files changed, 19 insertions, 4 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 50266d5..92c7e1b 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -318,6 +318,7 @@ class Subgraph:
ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0]
ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs]
+ # get_all_ops is used when traversing the original graph
def get_all_ops(self):
all_ops = []
visit_op_set = set()
@@ -344,6 +345,20 @@ class Subgraph:
return all_ops
+ # get_all_ops_from_passes is used by stats writer to calculate the number of
+ # CPU and NPU ops
+ # Due to a side effect get_all_ops might not be traversing the full graph
+ # after extract_npu_subgraph have been called and should not be used by stats writer.
+ # The reason is that the main graph might have NPU nodes with no visible outputs
+ # and therefore the nodes will be missed.
+ def get_all_ops_from_passes(self):
+ all_ops = []
+ for idx, ps in enumerate(self.passes):
+ for op in ps.ops:
+ all_ops.append(op)
+
+ return all_ops
+
def print_operators(self, ignore_placeholder_const=True, show_attributes=True):
print(f"Operators of Subgraph {self.name}")
diff --git a/ethosu/vela/stats_writer.py b/ethosu/vela/stats_writer.py
index 25c9030..b743a5f 100644
--- a/ethosu/vela/stats_writer.py
+++ b/ethosu/vela/stats_writer.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -365,11 +365,11 @@ def print_performance_metrics(nng, arch, show_cpu_operations=False, verbose_weig
for sg in nng.subgraphs:
if sg.placement == PassPlacement.Cpu:
- for op in sg.get_all_ops():
+ for op in sg.get_all_ops_from_passes():
if op.type not in ir_only_ops:
cpu_operations.append(op)
elif sg.placement == PassPlacement.Npu:
- for op in sg.get_all_ops():
+ for op in sg.get_all_ops_from_passes():
if op.type not in ir_only_ops:
npu_operations.append(op)