about | summary | refs | log | tree | commit | diff
path: root/ethosu/vela/live_range.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r-- ethosu/vela/live_range.py 32
1 files changed, 19 insertions, 13 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index e683f9f5..9b6fe63d 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -224,18 +224,24 @@ def extract_live_ranges_from_cascaded_passes(
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(time_for_pass)
- cps_primary_op = cps.passes[0].primary_op
-
- if (
- cps_primary_op
- and cps_primary_op.type == Op.CustomNpuOp
- and MemType.Permanent_CPU not in target_mem_type_set
- ):
- # If the primary-op is an NpuOp that means this is where an Npu subgraph
- # is called. Go into said subgraph and extract live ranges before continuing.
- # Use default allocation alignment of 16 for Npu tensors
- npu_sg = cps_primary_op.attrs["subgraph"]
- lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph)
+ op_subgraph = cps.passes[0].ops[0].attrs.get("subgraph", None)
+ op_type = cps.passes[0].ops[0].type
+
+ if op_subgraph is not None and MemType.Permanent_CPU not in target_mem_type_set:
+ if op_type == Op.CustomNpuOp:
+ # If the primary-op is an NpuOp that means this is where an Npu subgraph
+ # is called. Go into said subgraph and extract live ranges before continuing.
+ # Use default allocation alignment of 16 for Npu tensors
+ lr_graph = _extract_live_ranges_from_schedule(
+ op_subgraph, target_mem_area, target_mem_type_set, lr_graph
+ )
+ else:
+ # The op has one or more subgraphs in it (a typical op is the While op)
+ # Go into all subgraphs and extract live ranges before continuing.
+ for op_sg in op_subgraph:
+ lr_graph = extract_live_ranges_from_cascaded_passes(
+ op_sg, target_mem_area, target_mem_type_set, lr_graph, cpu_tensor_alignment
+ )
# Set the new time after handling the Npu subgraph
time_for_pass = lr_graph.current_time
cps.time = time_for_pass