diff options
-rw-r--r-- | ethosu/vela/live_range.py | 29 |
1 files changed, 13 insertions, 16 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index 3abcfcf0..d64f68e0 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -140,15 +140,16 @@ class LiveRangeGraph: self.ranges[out_tens] = live_range return live_range - def update_endtime(self): - self.end_time = self.current_time - return self.end_time + 1 + def get_endtime(self): + # op_length is 1 so max end time for lr is current + 1 + return self.current_time + 1 def get_temporal_memory_usage(self, target_mem_area): - usage = np.zeros(self.update_endtime(), dtype=np.int32) + usage = np.zeros(self.get_endtime() + 1, dtype=np.int32) for lr in self.lrs: if lr.mem_area == target_mem_area: # End time is inclusive + assert lr.end_time <= self.get_endtime() + 1 usage[lr.start_time : lr.end_time + 1] += lr.size return usage @@ -268,8 +269,11 @@ def extract_live_ranges_from_cascaded_passes( op_sg, target_mem_area, target_mem_type_set, lr_graph, cpu_tensor_alignment ) # Set the new time after handling the Npu subgraph + # current time is updated in subgraph path so do not tick the time time_for_pass = lr_graph.current_time cps.time = time_for_pass + else: + lr_graph.current_time += 2 for tens in cps.intermediates + cps.outputs: if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set): @@ -277,23 +281,17 @@ def extract_live_ranges_from_cascaded_passes( rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment) rng.mark_usage(time_for_pass) - lr_graph.current_time += 2 - - end_time = 0 - for rng in lr_graph.ranges.values(): - # Find the maximum end time of all live-ranges in the graph - end_time = max(end_time, rng.end_time) - + time_to_set = lr_graph.current_time for tens in sg.output_tensors: if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set): continue rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment) - rng.mark_usage(end_time) + rng.mark_usage(time_to_set) # Variable tensor live-range is for entire inference for tens, rng in lr_graph.ranges.items(): if tens.is_variable: - rng.mark_usage(0, end_time + 1) + rng.mark_usage(0, time_to_set + 1) # Add subgraph to set of processed subgraphs lr_graph.processed_subgraphs.add(sg) @@ -384,12 +382,11 @@ def extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, if cascade != 0: time_for_cascade[cascade] = time_to_set - end_time = lr_graph.update_endtime() - + time_to_set = lr_graph.current_time for tens in sg.output_tensors: if tens.mem_type not in target_mem_type_set or tens.mem_area != target_mem_area: continue rng = lr_graph.get_or_create_range(tens) - rng.mark_usage(end_time) + rng.mark_usage(time_to_set) return lr_graph |