diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2021-10-27 13:58:03 +0200 |
---|---|---|
committer | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2021-10-27 13:58:03 +0200 |
commit | 0ae2848580d376c66d6a3c3b0fb55b30a0234247 (patch) | |
tree | e36aacb086a50d61a3c413faef3e83f78938ed27 /ethosu/vela/live_range.py | |
parent | 2b5939f639d3ceb9fcf75c2edc78d128676119b3 (diff) | |
download | ethos-u-vela-0ae2848580d376c66d6a3c3b0fb55b30a0234247.tar.gz |
MLBEDSW-5450 MLCE: Vela to handle skip Tensor
Added checks to avoid merging elementwise op live ranges for subgraph
inputs and outputs, which sometimes caused problems when parts of the
network run on CPU.
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Id07ab277a205b8550d19a276559f8904b9a4b4be
Diffstat (limited to 'ethosu/vela/live_range.py')
-rw-r--r-- | ethosu/vela/live_range.py | 34 |
1 file changed, 12 insertions, 22 deletions
diff --git a/ethosu/vela/live_range.py b/ethosu/vela/live_range.py
index 0b94b197..5aec0dfb 100644
--- a/ethosu/vela/live_range.py
+++ b/ethosu/vela/live_range.py
@@ -112,7 +112,6 @@ class LiveRangeGraph:
     def __init__(self):
         self.lrs: List[LiveRange] = []  # List of all created ranges
         self.ranges = {}  # tens -> range
-        self.ignore_tensors = set()
         self.processed_subgraphs = set()
         self.current_time = 0
         self.end_time = None
@@ -151,17 +150,17 @@ class LiveRangeGraph:
         return usage
 
 
-def tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+def tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
     if tens.mem_area != target_mem_area or tens.mem_type not in target_mem_type_set:
         return True
-    if tens in lr_graph.ignore_tensors:
-        return True
     return False
 
 
-def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set):
+def merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set):
     def _tensor_should_be_ignored(tens):
-        return tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set)
+        if tens in sg.input_tensors + sg.output_tensors:
+            return True
+        return tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set)
 
     # Tries to merge ifm/ofm live ranges of elementwise op
     if sched_op.op_type.is_elementwise_op():
@@ -198,12 +197,7 @@ def merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_
 
 
 def extract_live_ranges_from_cascaded_passes(
-    sg,
-    target_mem_area,
-    target_mem_type_set,
-    ignore_subgraph_input_output_tensors=False,
-    lr_graph=None,
-    cpu_tensor_alignment=Tensor.AllocationQuantum,
+    sg, target_mem_area, target_mem_type_set, lr_graph=None, cpu_tensor_alignment=Tensor.AllocationQuantum,
 ):
     if lr_graph is None:
         lr_graph = LiveRangeGraph()
@@ -212,17 +206,13 @@ def extract_live_ranges_from_cascaded_passes(
         # if subgraph has been processed already, return the lr_graph as is
         return lr_graph
 
-    if ignore_subgraph_input_output_tensors:
-        lr_graph.ignore_tensors.update(sg.input_tensors)
-        lr_graph.ignore_tensors.update(sg.output_tensors)
-
     for cps in sg.cascaded_passes:
         cps.time = lr_graph.current_time
 
         time_for_pass = cps.time
 
         for tens in cps.inputs:
-            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
             rng.mark_usage(time_for_pass)
@@ -244,7 +234,7 @@ def extract_live_ranges_from_cascaded_passes(
         cps.time = time_for_pass
 
         for tens in cps.intermediates + cps.outputs:
-            if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+            if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
                 continue
             rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
             rng.mark_usage(time_for_pass)
@@ -257,7 +247,7 @@ def extract_live_ranges_from_cascaded_passes(
         end_time = max(end_time, rng.end_time)
 
     for tens in sg.output_tensors:
-        if tensor_should_be_ignored(lr_graph, tens, target_mem_area, target_mem_type_set):
+        if tensor_should_be_ignored(tens, target_mem_area, target_mem_type_set):
             continue
         rng = lr_graph.get_or_create_range(tens, cpu_tensor_alignment)
         rng.mark_usage(end_time)
@@ -273,7 +263,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
     for ps in sg.passes:
         for tens in ps.inputs + ps.outputs + ps.intermediates:
             if tens.purpose == TensorPurpose.Weights or tensor_should_be_ignored(
-                lr_graph, tens, target_mem_area, target_mem_type_set
+                tens, target_mem_area, target_mem_type_set
             ):
                 continue
             rng = lr_graph.get_or_create_range(tens)
@@ -281,7 +271,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
 
     for _, op_info in sg.schedule.cost_map.items():
         for tensor in [op_info.npu_weights_tensor, op_info.npu_scales_tensor]:
-            if tensor and not (tensor_should_be_ignored(lr_graph, tensor, target_mem_area, target_mem_type_set)):
+            if tensor and not (tensor_should_be_ignored(tensor, target_mem_area, target_mem_type_set)):
                 rng = lr_graph.get_or_create_range(tensor)
                 rng.mark_usage(sg_time)
 
@@ -292,7 +282,7 @@ def create_linear_live_range_graph(sg, target_mem_area, target_mem_type_set, lr_
 def _extract_live_ranges_from_schedule(sg, target_mem_area, target_mem_type_set, lr_graph):
     time_for_cascade = {}
     for sched_op in sg.sched_ops:
-        merge_elementwise_op_ranges(sched_op, lr_graph, target_mem_area, target_mem_type_set)
+        merge_elementwise_op_ranges(sg, sched_op, lr_graph, target_mem_area, target_mem_type_set)
         op_info = sg.schedule.cost_map[sched_op]
         cascade = op_info.cascade
 