diff options
Diffstat (limited to 'ethosu/vela/nn_graph.py')
-rw-r--r-- | ethosu/vela/nn_graph.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py index 50266d5..92c7e1b 100644 --- a/ethosu/vela/nn_graph.py +++ b/ethosu/vela/nn_graph.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com> +# SPDX-FileCopyrightText: Copyright 2020-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> # # SPDX-License-Identifier: Apache-2.0 # @@ -318,6 +318,7 @@ class Subgraph: ps.outputs = [out_tens for out_tens in ps.outputs if len(out_tens.consumers()) > 0] ps.ops = [op for op in ps.ops if op.outputs[0] in ps.outputs] + # get_all_ops is used when traversing the original graph def get_all_ops(self): all_ops = [] visit_op_set = set() @@ -344,6 +345,20 @@ class Subgraph: return all_ops + # get_all_ops_from_passes is used by stats writer to calculate the number of + # CPU and NPU ops + # Due to a side effect get_all_ops might not be traversing the full graph + # after extract_npu_subgraph have been called and should not be used by stats writer. + # The reason is that the main graph might have NPU nodes with no visible outputs + # and therefore the nodes will be missed. + def get_all_ops_from_passes(self): + all_ops = [] + for idx, ps in enumerate(self.passes): + for op in ps.ops: + all_ops.append(op) + + return all_ops + def print_operators(self, ignore_placeholder_const=True, show_attributes=True): print(f"Operators of Subgraph {self.name}") |