diff options
author | Tim Hall <tim.hall@arm.com> | 2021-05-27 18:49:40 +0100 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2021-05-27 18:57:39 +0100 |
commit | d8339a75c9b655c0507e34238078fdad068b4023 (patch) | |
tree | 36a14726b30760169a83c0356803b480992fade8 /ethosu/vela/live_range.py | |
parent | 64556f32ff7bfca6036a6598034464b13b64a4ef (diff) | |
download | ethos-u-vela-d8339a75c9b655c0507e34238078fdad068b4023.tar.gz |
MLBEDSW-4034: New Scheduler Size or Performance Optimisation
- Merged dev/scheduler at 83639f90e8c828f70de6e29142355a940224959b
Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I0050529d4b42da93768c7264296434dd877fb5b4
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r-- | ethosu/vela/live_range.py | 125 |
1 files changed, 118 insertions, 7 deletions
```diff
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index de001e56..d75a167d 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -18,10 +18,14 @@
 # Can work with either a pass packed subgraph or a scheduled subgraph.
 from typing import List
 
+import numpy as np
+
 from .nn_graph import PassPlacement
 from .operation import Op
+from .tensor import MemArea
 from .tensor import MemType
 from .tensor import Tensor
+from .tensor import TensorPurpose
 
 
 class LiveRange:
@@ -32,6 +36,7 @@ class LiveRange:
         self.size = 0
         self.name = ""
         self.alignment = alignment
+        self.mem_area = tens.mem_area if tens else MemArea.Unknown
 
         if tens:
             self.add_tensor(tens)
@@ -52,15 +57,19 @@ class LiveRange:
 
         self.tensors.append(tens)
 
-    def mark_usage(self, op_time):
-        if op_time == -1:
+    def mark_usage(self, op_time, op_length=1):
+        op_time_start = max(op_time, 0)
+        op_time_end = op_time + op_length
+        if op_time_end <= op_time_start:
             return
-        op_time_start = op_time
-        op_time_end = op_time + 1
 
         self.start_time = min(self.start_time, op_time_start)
         self.end_time = max(self.end_time, op_time_end)
 
+    def set_buffer_size(self, buffer_size):
+        self.size = buffer_size
+        self.mem_area = MemArea.Sram
+
     def overlaps_ranges(self, other):
         return max(self.start_time, other.start_time) < min(self.end_time, other.end_time)
 
@@ -106,6 +115,7 @@ class LiveRangeGraph:
         self.ignore_tensors = set()
         self.processed_subgraphs = set()
         self.current_time = 0
+        self.end_time = None
 
     def get_or_create_range(self, tens, alignment=Tensor.AllocationQuantum):
         # Return the live range of the tensor (or any of its clones)
@@ -127,6 +137,23 @@ class LiveRangeGraph:
         self.ranges[out_tens] = live_range
         return live_range
 
+    def update_endtime(self):
+        self.end_time = 0
+        for rng in self.ranges.values():
+            self.end_time = max(self.end_time, rng.end_time)
+        return self.end_time + 1
+
+    def get_temporal_memory_usage(self, target_mem_area):
+        if not self.end_time:
+            self.update_endtime()
+        usage = np.zeros(self.end_time, dtype=np.int32)
+        for rng in self.ranges.values():
+            if rng.mem_area == target_mem_area:
+                # End time is inclusive
+                usage[rng.start_time : rng.end_time + 1] += rng.size
+
+        return usage
+
 
 def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
     if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
@@ -279,9 +306,7 @@ def extract_live_ranges_from_cascaded_passes(
             # is called. Go into said subgraph and extract live ranges before continuing.
             # Use default allocation alignment of 16 for Npu tensors
             npu_sg = cps_primary_op.attrs["subgraph"]
-            lr_graph = extract_live_ranges_from_cascaded_passes(
-                npu_sg, target_mem_area, target_mem_type_set, False, lr_graph,
-            )
+            lr_graph = _extract_live_ranges_from_schedule(npu_sg, target_mem_area, target_mem_type_set, lr_graph)
             # Set the new time after handling the Npu subgraph
             time_for_pass = lr_graph.current_time
             cps.time = time_for_pass
@@ -308,3 +333,89 @@ def extract_live_ranges_from_cascaded_passes(
     # Add subgraph to set of processed subgraphs
     lr_graph.processed_subgraphs.add(sg)
     return lr_graph
+
+
+def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_graph):
+    assert lr_graph is not None
+    sg_time = lr_graph.current_time
+    for ps in sg.passes:
+        for tens in ps.inputs + ps.outputs + ps.intermediates:
+            if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
+                lr_graph, tens, target_mem_area, target_mem_type_set
+            ):
+                continue
+
+            rng = lr_graph.get_or_create_range(tens)
+            rng.mark_usage(sg_time)
+
+    for sched_op, op_info in sg.schedule.cost_map.items():
+        if op_info.npu_weights_tensor and not (
+            tensor_should_be_ignored(lr_graph, op_info.npu_weights_tensor, target_mem_area, target_mem_type_set)
+        ):
+            rng = lr_graph.get_or_create_range(op_info.npu_weights_tensor)
+            rng.mark_usage(sg_time)
+
+    lr_graph.current_time += 1
+    return lr_graph
+
+
+def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
+    time_for_cascade = {}
+    for sched_op in sg.sched_ops:
+        op_info = sg.schedule.cost_map[sched_op]
+        cascade = op_info.cascade
+        cascade_info = sg.schedule.cascades.get(cascade, None)
+
+        time_to_set = time_for_cascade.get(cascade, lr_graph.current_time)
+
+        op_info.time_index = time_to_set
+
+        # Mark usage for all tensors related to this Pass
+        ps = sched_op.parent_ps
+        for tens in ps.inputs + ps.outputs + ps.intermediates:
+            if (
+                target_mem_area == MemArea.Sram
+                and cascade_info
+                and tens == ps.ifm_tensor
+                and sched_op in cascade_info.buffers
+            ):
+                # This tensor is a rolling buffer in a cascade and the size of the LiveRange needs to be modified
+                # for enabling temporal memory snapshots without modifying the original Tensor
+                rng = lr_graph.get_or_create_range(tens)
+                rng.set_buffer_size(cascade_info.buffers[sched_op].elements() * sched_op.ifm.dtype.size_in_bytes())
+            elif (
+                tens.purpose == TensorPurpose.Weights
+                or tens.purpose == TensorPurpose.FSBias
+                or tens.mem_type not in target_mem_type_set
+                or tens.mem_area != target_mem_area
+            ):
+                continue
+
+            else:
+                rng = lr_graph.get_or_create_range(tens)
+
+            rng.mark_usage(time_to_set)
+
+        weight_tens = op_info.buffered_weight_tensor
+        if weight_tens and weight_tens.mem_type in target_mem_type_set and weight_tens.mem_area == target_mem_area:
+            rng = lr_graph.get_or_create_range(weight_tens)
+            if weight_tens.pre_buffer:
+                rng.mark_usage(time_to_set - 1, 2)
+            else:
+                rng.mark_usage(time_to_set)
+
+        if time_to_set == lr_graph.current_time:
+            lr_graph.current_time += 2
+
+        if cascade != 0:
+            time_for_cascade[cascade] = time_to_set
+
+    end_time = lr_graph.update_endtime()
+
+    for tens in sg.output_tensors:
+        if tens.mem_type not in target_mem_type_set or tens.mem_area != target_mem_area:
+            continue
+        rng = lr_graph.get_or_create_range(tens)
+        rng.mark_usage(end_time)
+
+    return lr_graph
```