diff options
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r-- | ethosu/vela/live_range.py | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index e683f9f5..9b6fe63d 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -224,18 +224,24 @@ def extract_live_ranges_from_cascaded_passes( rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment) rng.mark_usage(time_for_pass) - cps_primary_op = cps.passes[0].primary_op - - if ( - cps_primary_op - and cps_primary_op.type == Op.CustomNpuOp - and MemType.Permanent_CPU not in target_mem_type_set - ): - # If the primary-op is an NpuOp that means this is where an Npu subgraph - # is called. Go into said subgraph and extract live ranges before continuing. - # Use default allocation alignment of 16 for Npu tensors - npu_sg = cps_primary_op.attrs["subgraph"] - lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph) + op_subgraph = cps.passes[0].ops[0].attrs.get("subgraph", None) + op_type = cps.passes[0].ops[0].type + + if op_subgraph is not None and MemType.Permanent_CPU not in target_mem_type_set: + if op_type == Op.CustomNpuOp: + # If the primary-op is an NpuOp that means this is where an Npu subgraph + # is called. Go into said subgraph and extract live ranges before continuing. + # Use default allocation alignment of 16 for Npu tensors + lr_graph = _extract_live_ranges_from_schedule( + op_subgraph, target_mem_area, target_mem_type_set, lr_graph + ) + else: + # The op has one or more subgraphs in it (a typical op is the While op) + # Go into all subgraphs and extract live ranges before continuing. + for op_sg in op_subgraph: + lr_graph = extract_live_ranges_from_cascaded_passes( + op_sg, target_mem_area, target_mem_type_set, lr_graph, cpu_tensor_alignment + ) # Set the new time after handling the Npu subgraph time_for_pass = lr_graph.current_time cps.time = time_for_pass |