aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/live_range.py29
1 files changed, 13 insertions, 16 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 3abcfcf0..d64f68e0 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -140,15 +140,16 @@ class LiveRangeGraph:
self.ranges[out_tens] = live_range
return live_range
- def update_endtime(self):
- self.end_time = self.current_time
- return self.end_time + 1
+ def get_endtime(self):
+ # op_length is 1 so max end time for lr is current + 1
+ return self.current_time + 1
def get_temporal_memory_usage(self, target_mem_area):
- usage = np.zeros(self.update_endtime(), dtype=np.int32)
+ usage = np.zeros(self.get_endtime() + 1, dtype=np.int32)
for lr in self.lrs:
if lr.mem_area == target_mem_area:
# End time is inclusive
+ assert lr.end_time <= self.get_endtime() + 1
usage[lr.start_time : lr.end_time + 1] += lr.size
return usage
@@ -268,8 +269,11 @@ def extract_live_ranges_from_cascaded_passes(
op_sg, target_mem_area, target_mem_type_set, lr_graph, cpu_tensor_alignment
)
# Set the new time after handling the Npu subgraph
+ # current time is updated in subgraph path so do not tick the time
time_for_pass = lr_graph.current_time
cps.time = time_for_pass
+ else:
+ lr_graph.current_time += 2
for tens in cps.intermediates + cps.outputs:
if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
@@ -277,23 +281,17 @@ def extract_live_ranges_from_cascaded_passes(
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
rng.mark_usage(time_for_pass)
- lr_graph.current_time += 2
-
- end_time = 0
- for rng in lr_graph.ranges.values():
- # Find the maximum end time of all live-ranges in the graph
- end_time = max(end_time, rng.end_time)
-
+ time_to_set = lr_graph.current_time
for tens in sg.output_tensors:
if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
continue
rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
- rng.mark_usage(end_time)
+ rng.mark_usage(time_to_set)
# Variable tensor live-range is for entire inference
for tens, rng in lr_graph.ranges.items():
if tens.is_variable:
- rng.mark_usage(0, end_time + 1)
+ rng.mark_usage(0, time_to_set + 1)
# Add subgraph to set of processed subgraphs
lr_graph.processed_subgraphs.add(sg)
@@ -384,12 +382,11 @@ def extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set,
if cascade != 0:
time_for_cascade[cascade] = time_to_set
- end_time = lr_graph.update_endtime()
-
+ time_to_set = lr_graph.current_time
for tens in sg.output_tensors:
if tens.mem_type not in target_mem_type_set or tens.mem_area != target_mem_area:
continue
rng = lr_graph.get_or_create_range(tens)
- rng.mark_usage(end_time)
+ rng.mark_usage(time_to_set)
return lr_graph