diff options
-rw-r--r-- | ethosu/vela/nn_graph.py | 17 | ||||
-rw-r--r-- | ethosu/vela/stats_writer.py | 6 |
2 files changed, 19 insertions, 4 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 50266d5..92c7e1b 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -318,6 +318,7 @@ class Subgraph: ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0] ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs] + # get_all_ops is used when traversing the original graph def get_all_ops(self): all_ops = [] visit_op_set = set() @@ -344,6 +345,20 @@ class Subgraph: return all_ops + # get_all_ops_from_passes is used by stats writer to calculate the number of + # CPU and NPU ops + # Due to a side effect get_all_ops might not be traversing the full graph + # after extract_npu_subgraph have been called and should not be used by stats writer. + # The reason is that the main graph might have NPU nodes with no visible outputs + # and therefore the nodes will be missed. + def get_all_ops_from_passes(self): + all_ops = [] + for idx, ps in enumerate(self.passes): + for op in ps.ops: + all_ops.append(op) + + return all_ops + def print_operators(self, ignore_placeholder_const=True, show_attributes=True): print(f"Operators of Subgraph {self.name}") diff --git a/ethosu/vela/stats_writer.py b/ethosu/vela/stats_writer.py index 25c9030..b743a5f 100644 --- a/ethosu/vela/stats_writer.py +++ b/ethosu/vela/stats_writer.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -365,11 +365,11 @@ def print_performance_metrics(nng, arch, show_cpu_operations=False, verbose_weig for sg in nng.subgraphs: if sg.placement == PassPlacement.Cpu: - for op in sg.get_all_ops(): + for op in sg.get_all_ops_from_passes(): if op.type not in ir_only_ops: cpu_operations.append(op) elif sg.placement == PassPlacement.Npu: - for op in sg.get_all_ops(): + for op in sg.get_all_ops_from_passes(): if op.type not in ir_only_ops: npu_operations.append(op) |