diff options
-rw-r--r-- | ethosu/vela/live_range.py | 34 | ||||
-rw-r--r-- | ethosu/vela/scheduler.py | 4 | ||||
-rw-r--r-- | ethosu/vela/tensor_allocation.py | 8 |
3 files changed, 15 insertions, 31 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py index 0b94b197..5aec0dfb 100644 --- a/ethosu/vela/live_range.py +++ b/ethosu/vela/live_range.py @@ -112,7 +112,6 @@ class LiveRangeGraph: def __init__(self): self.lrs: List[LiveRange] = [] # List of all created ranges self.ranges = {} # tens -> range - self.ignore_tensors = set() self.processed_subgraphs = set() self.current_time = 0 self.end_time = None @@ -151,17 +150,17 @@ class LiveRangeGraph: return usage -def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): +def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set): if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set: return True - if tens in lr_graph.ignore_tensors: - return True return False -def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set): +def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set): def _tensor_should_be_ignored(tens): - return tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set) + if tens in sg.input_tensors + sg.output_tensors: + return True + return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set) # Tries to merge ifm/ofm live ranges of elementwise op if sched_op.op_type.is_elementwise_op(): @@ -198,12 +197,7 @@ def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_ def extract_live_ranges_from_cascaded_passes( - sg, - target_mem_area, - target_mem_type_set, - ignore_subgraph_input_output_tensors=False, - lr_graph=None, - cpu_tensor_alignment=Tensor.AllocationQuantum, + sg, target_mem_area, target_mem_type_set, lr_graph=None, cpu_tensor_alignment=Tensor.AllocationQuantum, ): if lr_graph is None: lr_graph = LiveRangeGraph() @@ -212,17 +206,13 @@ def extract_live_ranges_from_cascaded_passes( # if subgraph has been processed already, return the lr_graph as is return lr_graph - if ignore_subgraph_input_output_tensors: - lr_graph.ignore_tensors.update(sg.input_tensors) - lr_graph.ignore_tensors.update(sg.output_tensors) - for cps in sg.cascaded_passes: cps.time = lr_graph.current_time time_for_pass = cps.time for tens in cps.inputs: - if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): + if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set): continue rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment) rng.mark_usage(time_for_pass) @@ -244,7 +234,7 @@ def extract_live_ranges_from_cascaded_passes( cps.time = time_for_pass for tens in cps.intermediates + cps.outputs: - if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): + if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set): continue rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment) rng.mark_usage(time_for_pass) @@ -257,7 +247,7 @@ def extract_live_ranges_from_cascaded_passes( end_time = max(end_time, rng.end_time) for tens in sg.output_tensors: - if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set): + if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set): continue rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment) rng.mark_usage(end_time) @@ -273,7 +263,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_ for ps in sg.passes: for tens in ps.inputs + ps.outputs + ps.intermediates: if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored( - lr_graph, tens, target_mem_area, target_mem_type_set + tens, target_mem_area, target_mem_type_set ): continue rng = lr_graph.get_or_create_range(tens) @@ -281,7 +271,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_ for _, op_info in sg.schedule.cost_map.items(): for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]: - if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)): + if tensor and not (tensor_should_be_ignored(tensor, target_mem_area, target_mem_type_set)): rng = lr_graph.get_or_create_range(tensor) rng.mark_usage(sg_time) @@ -292,7 +282,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_ def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph): time_for_cascade = {} for sched_op in sg.sched_ops: - merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set) + merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set) op_info = sg.schedule.cost_map[sched_op] cascade = op_info.cascade diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py index 2ac47878..782e8d98 100644 --- a/ethosu/vela/scheduler.py +++ b/ethosu/vela/scheduler.py @@ -425,7 +425,7 @@ class Scheduler: lr_graph = live_range.LiveRangeGraph() for mem_area, mem_type_set in memories_list: live_range.extract_live_ranges_from_cascaded_passes( - self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum, + self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum, ) # Populate time-array with memory used by live ranges @@ -918,7 +918,7 @@ class Scheduler: lr_graph = live_range.LiveRangeGraph() for mem_area, mem_type_set in memories_list: live_range.extract_live_ranges_from_cascaded_passes( - self.nng.get_root_subgraph(), mem_area, mem_type_set, False, lr_graph, Tensor.AllocationQuantum, + self.nng.get_root_subgraph(), mem_area, mem_type_set, lr_graph, Tensor.AllocationQuantum, ) # Iterate over live ranges and evict tensors that doesn't fit diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py index 4b5e5e42..db261ae5 100644 --- a/ethosu/vela/tensor_allocation.py +++ b/ethosu/vela/tensor_allocation.py @@ -182,14 +182,8 @@ def allocate( cpu_tensor_alignment=Tensor.AllocationQuantum, ): # Allocates addresses to tensors, returns False if tensors could not be fit within max_size - ignore_subgraph_input_output_tensors = False lrs = live_range.extract_live_ranges_from_cascaded_passes( - sg, - mem_area, - mem_type_set, - ignore_subgraph_input_output_tensors=ignore_subgraph_input_output_tensors, - lr_graph=lr_graph, - cpu_tensor_alignment=cpu_tensor_alignment, + sg, mem_area, mem_type_set, lr_graph=lr_graph, cpu_tensor_alignment=cpu_tensor_alignment, ) total_sz = 0 if lrs.ranges: |